Spaces:
Runtime error
ugmSorcero committed • Commit 158f4dc

Adds files from huggingface hub repo
Browse files:
- .gitattributes +31 -0
- .gitignore +5 -0
- README.md +12 -0
- app.py +88 -0
- dataset.py +27 -0
- model.py +191 -0
- requirements.txt +11 -0
- train.py +110 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,5 @@
+feedback*
+new_model/
+__pycache__/
+data/
+events.out.*
README.md
ADDED
@@ -0,0 +1,12 @@
+---
+title: Grocery Classifier Demo
+emoji: 🛒
+colorFrom: red
+colorTo: green
+sdk: streamlit
+sdk_version: 1.10.0
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,88 @@
+import os
+import streamlit as st
+from PIL import Image
+import requests
+import io
+import time
+from model import ViTForImageClassification
+
+st.set_page_config(
+    page_title="Grocery Classifier",
+    page_icon="interface/shopping-cart.png",
+    initial_sidebar_state="expanded"
+)
+
+@st.cache()
+def load_model():
+    with st.spinner("Loading model"):
+        model = ViTForImageClassification('google/vit-base-patch16-224')
+        model.load('model/')
+    return model
+
+model = load_model()
+feedback_path = "feedback"
+
+def predict(image):
+    print("Predicting...")
+    # Load using PIL
+    image = Image.open(image)
+
+    prediction, confidence = model.predict(image)
+
+    return {'prediction': prediction[0], 'confidence': round(confidence[0], 3)}, image
+
+def submit_feedback(correct_label, image):
+    folder_path = feedback_path + "/" + correct_label + "/"
+    os.makedirs(folder_path, exist_ok=True)
+    image.save(folder_path + correct_label + "_" + str(int(time.time())) + ".png")
+
+def retrain_from_feedback():
+    model.retrain_from_path(feedback_path, remove_path=True)
+
+def main():
+    labels = set(list(model.label_encoder.classes_))
+
+    st.title("🍇 Grocery Classifier 🥑")
+
+    if labels is None:
+        st.warning("Received error from server, labels could not be retrieved")
+    else:
+        st.write("Labels:", labels)
+
+    image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+    if image_file is not None:
+        st.image(image_file)
+
+        st.subheader("Classification")
+
+        if st.button("Predict"):
+            st.session_state['response_json'], st.session_state['image'] = predict(image_file)
+
+        if 'response_json' in st.session_state and st.session_state['response_json'] is not None:
+            # Show the result
+            st.markdown(f"**Prediction:** {st.session_state['response_json']['prediction']}")
+            st.markdown(f"**Confidence:** {st.session_state['response_json']['confidence']}")
+
+            # User feedback
+            st.subheader("User Feedback")
+            st.markdown("If this prediction was incorrect, please select below the correct label")
+            correct_labels = labels.copy()
+            correct_labels.remove(st.session_state['response_json']["prediction"])
+            correct_label = st.selectbox("Correct label", correct_labels)
+            if st.button("Submit"):
+                # Save feedback
+                try:
+                    submit_feedback(correct_label, st.session_state['image'])
+                    st.success("Feedback submitted")
+                except Exception as e:
+                    st.error("Feedback could not be submitted. Error: {}".format(e))
+
+    # Retrain from feedback
+    if st.button("Retrain from feedback"):
+        try:
+            retrain_from_feedback()
+            st.success("Model retrained")
+        except Exception as e:
+            st.warning("Model could not be retrained. Error: {}".format(e))
+
+main()
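
Note: the snippet below is a minimal usage sketch, not part of the commit. It exercises the same prediction path app.py uses, but outside Streamlit; it assumes fine-tuned weights and the label encoder live under model/ (model.pt, label_encoder.npy), and the image path 'examples/apple.jpg' is a placeholder.

    # Minimal sketch of app.py's prediction path, run outside Streamlit
    from PIL import Image
    from model import ViTForImageClassification

    classifier = ViTForImageClassification('google/vit-base-patch16-224')
    classifier.load('model/')                 # restores classifier head and label encoder

    image = Image.open('examples/apple.jpg')  # hypothetical test image
    prediction, confidence = classifier.predict(image)
    print(prediction[0], round(confidence[0], 3))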
dataset.py
ADDED
@@ -0,0 +1,27 @@
+import torch
+
+class RetailDataset(torch.utils.data.Dataset):
+    def __init__(self, data, labels=None, transform=None):
+        self.data = data
+        self.labels = labels
+        self.num_classes = len(set(labels))
+        self.transform = transform
+
+    def __getitem__(self, idx):
+        item = {key: val[idx].detach().clone() for key, val in self.data.items()}
+        item['labels'] = self.labels[idx]
+        return item
+
+    def __len__(self):
+        return len(self.labels)
+
+    def __repr__(self):
+        return 'RetailDataset'
+
+    def __str__(self):
+        return str({
+            'data': self.data['pixel_values'].shape,
+            'labels': self.labels.shape,
+            'num_classes': self.num_classes,
+            'num_samples': len(self.labels)
+        })
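
Note: a small, hypothetical sketch (not part of the commit) of how train.py feeds preprocessed tensors into RetailDataset; the dummy tensors and label ids below are illustrative only.

    import numpy as np
    import torch
    from transformers import BatchFeature
    from dataset import RetailDataset

    # Four already-preprocessed images and their integer-encoded labels (dummy values)
    pixel_values = torch.randn(4, 3, 224, 224)
    labels = np.array([0, 1, 2, 1])
    features = BatchFeature(data={"pixel_values": pixel_values})

    ds = RetailDataset(features, labels)
    print(len(ds))          # 4
    print(ds[0]["labels"])  # 0
    print(ds)               # shapes, num_classes and num_samples via __str__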
model.py
ADDED
@@ -0,0 +1,191 @@
+import shutil
+import time
+import numpy as np
+from tqdm import tqdm
+from transformers import ViTModel, ViTFeatureExtractor
+from transformers.modeling_outputs import SequenceClassifierOutput
+import torch.nn as nn
+import torch
+from PIL import Image
+import logging
+import os
+from sklearn.preprocessing import LabelEncoder
+from train import (
+    re_training, metric, f1_score,
+    classification_report
+)
+
+data_path = os.environ.get('DATA_PATH', "./data")
+
+logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
+logger = logging.getLogger(__name__)
+
+class ViTForImageClassification(nn.Module):
+    def __init__(self, model_name, num_labels=24, dropout=0.25, image_size=224):
+        logger.info("Loading model")
+        super(ViTForImageClassification, self).__init__()
+        self.vit = ViTModel.from_pretrained(model_name)
+        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
+        self.feature_extractor.do_resize = True
+        self.feature_extractor.size = image_size
+        self.dropout = nn.Dropout(dropout)
+        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
+        self.num_labels = num_labels
+        self.label_encoder = LabelEncoder()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model_name = model_name
+        # To device
+        self.vit.to(self.device)
+        self.to(self.device)
+        self.classifier.to(self.device)
+        logger.info("Model loaded")
+
+    def forward(self, pixel_values, labels):
+        logger.info("Forwarding")
+        pixel_values = pixel_values.to(self.device)
+        outputs = self.vit(pixel_values=pixel_values)
+        output = self.dropout(outputs.last_hidden_state[:,0])
+        logits = self.classifier(output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def preprocess_image(self, images):
+        logger.info("Preprocessing images")
+        return self.feature_extractor(images, return_tensors='pt')
+
+    def predict(self, images, batch_size=32, classes_names=True, return_probabilities=False):
+        logger.info("Predicting")
+        if not isinstance(images, list):
+            images = [images]
+        classes_list = []
+        confidence_list = []
+        for bs in tqdm(range(0, len(images), batch_size), desc="Preprocessing training images"):
+            images_batch = [image for image in images[bs:bs+batch_size]]
+            images_batch = self.preprocess_image(images_batch)['pixel_values']
+            sequence_classifier_output = self.forward(images_batch, None)
+            # Get max prob
+            probs = sequence_classifier_output.logits.softmax(dim=-1).tolist()
+            classes = np.argmax(probs, axis=1)
+            confidences = np.max(probs, axis=1)
+            classes_list.extend(classes)
+            confidence_list.extend(confidences)
+        if classes_names:
+            classes_list = self.label_encoder.inverse_transform(classes_list)
+        if return_probabilities:
+            return classes_list, confidence_list, probs
+        return classes_list, confidence_list
+
+    def save(self, path):
+        logger.info("Saving model")
+        os.makedirs(path, exist_ok=True)
+        torch.save(self.state_dict(), path + "/model.pt")
+        # Save label encoder
+        np.save(path + "/label_encoder.npy", self.label_encoder.classes_)
+
+    def load(self, path):
+        logger.info("Loading model")
+        # Load label encoder
+        # Check if label encoder and model exists
+        if not os.path.exists(path + "/label_encoder.npy") or not os.path.exists(path + "/model.pt"):
+            logger.warning("Label encoder or model not found")
+            return
+        self.label_encoder.classes_ = np.load(path + "/label_encoder.npy")
+        # Reload classifier layer
+        self.classifier = nn.Linear(self.vit.config.hidden_size, len(self.label_encoder.classes_))
+
+        self.load_state_dict(torch.load(path + "/model.pt", map_location=self.device))
+        self.vit.to(self.device)
+        self.vit.eval()
+        self.to(self.device)
+        self.eval()
+
+    def evaluate(self, images, labels):
+        logger.info("Evaluating")
+        labels = self.label_encoder.transform(labels)
+        # Predict
+        y_pred, _ = self.predict(images, classes_names=False)
+        # Evaluate
+        metrics = metric.compute(predictions=y_pred, references=labels)
+        f1 = f1_score.compute(predictions=y_pred, references=labels, average="macro")
+        print(classification_report(labels, y_pred, labels=[i for i in range(len(self.label_encoder.classes_))], target_names=self.label_encoder.classes_))
+        print(f"Accuracy: {metrics['accuracy']}")
+        print(f"F1: {f1}")
+
+    def partial_fit(self, images, labels, save_model_path='new_model', num_epochs=10):
+        logger.info("Partial fitting")
+        # Freeze ViT model but last layer
+        # params = [param for param in self.vit.parameters()]
+        # for param in params[:-1]:
+        #     param.requires_grad = False
+        # Model in training mode
+        self.vit.train()
+        self.train()
+        re_training(images, labels, self, save_model_path, num_epochs)
+        self.load(save_model_path)
+        self.vit.eval()
+        self.eval()
+        self.evaluate(images, labels)
+
+    def __load_from_path(self, path, num_per_label=None):
+        images = []
+        labels = []
+        for label in os.listdir(path):
+            count = 0
+            label_folder_path = os.path.join(path, label)
+            for image_file in tqdm(os.listdir(label_folder_path), desc="Resizing images for label {}".format(label)):
+                file_path = os.path.join(label_folder_path, image_file)
+                try:
+                    image = Image.open(file_path)
+                    image_shape = (self.feature_extractor.size, self.feature_extractor.size)
+                    if image.size != image_shape:
+                        image = image.resize(image_shape)
+                    images.append(image.convert('RGB'))
+                    labels.append(label)
+                    count += 1
+                except Exception as e:
+                    print(f"ERROR - Could not resize image {file_path} - {e}")
+                if num_per_label is not None and count >= num_per_label:
+                    break
+        return images, labels
+
+    def retrain_from_path(self,
+                          path='./data/feedback',
+                          num_per_label=None,
+                          save_model_path='new_model',
+                          remove_path=False,
+                          num_epochs=10,
+                          save_new_data=data_path + '/new_data'):
+        logger.info("Retraining from path")
+        # Load path
+        images, labels = self.__load_from_path(path, num_per_label)
+        # Retrain
+        self.partial_fit(images, labels, save_model_path, num_epochs)
+        # Save new data
+        if save_new_data is not None:
+            logger.info("Saving new data")
+            for i, (image, label) in enumerate(zip(images, labels)):
+                label_path = os.path.join(save_new_data, label)
+                os.makedirs(label_path, exist_ok=True)
+                image.save(os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg"))
+        # Remove path folder
+        if remove_path:
+            logger.info("Removing feedback path")
+            shutil.rmtree(path)
+
+    def evaluate_from_path(self, path, num_per_label=None):
+        logger.info("Evaluating from path")
+        # Load images
+        images, labels = self.__load_from_path(path, num_per_label)
+        # Evaluate
+        self.evaluate(images, labels)
+
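
Note: a hypothetical sketch (not part of the commit) of the feedback-retraining loop that app.py triggers through retrain_from_path; the folder layout feedback/<label>/<image>.png matches what submit_feedback() writes, and the validation path below is a placeholder.

    from model import ViTForImageClassification

    clf = ViTForImageClassification('google/vit-base-patch16-224')
    clf.load('model/')   # prior fine-tuned weights and label encoder

    # Fine-tune on user-corrected images collected under feedback/<label>/,
    # then reload the checkpoint saved to new_model/
    clf.retrain_from_path(path='feedback', save_model_path='new_model',
                          num_epochs=1, remove_path=False)

    # Evaluate on a held-out folder with the same <label>/<image> layout (placeholder path)
    clf.evaluate_from_path('data/validation', num_per_label=50)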
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+Pillow
+requests
+numpy
+transformers
+scikit-learn
+datasets
+streamlit
+matplotlib
+scikit-image
+torch
+torchvision
train.py
ADDED
@@ -0,0 +1,110 @@
+import os
+import numpy as np
+from sklearn.metrics import classification_report
+from tqdm import tqdm
+import logging
+from sklearn.model_selection import train_test_split
+from dataset import RetailDataset
+from PIL import Image
+from datasets import load_metric
+from torchvision.transforms import (
+    CenterCrop,
+    Compose,
+    Normalize,
+    RandomHorizontalFlip,
+    RandomResizedCrop,
+    Resize,
+    ToTensor,
+)
+from transformers import Trainer, TrainingArguments, BatchFeature
+metric = load_metric("accuracy")
+f1_score = load_metric("f1")
+np.random.seed(42)
+
+logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
+logger = logging.getLogger(__name__)
+
+def prepare_dataset(images,
+                    labels,
+                    model,
+                    test_size=.2,
+                    train_transform=None,
+                    val_transform=None,
+                    batch_size=512):
+    logger.info("Preparing dataset")
+    # Split the dataset in train and test
+    try:
+        images_train, images_test, labels_train, labels_test = \
+            train_test_split(images, labels, test_size=test_size)
+    except ValueError:
+        logger.warning("Could not split dataset. Using all data for training and testing")
+        images_train = images
+        labels_train = labels
+        images_test = images
+        labels_test = labels
+
+    # Preprocess images using model feature extractor
+    images_train_prep = []
+    images_test_prep = []
+    for bs in tqdm(range(0, len(images_train), batch_size), desc="Preprocessing training images"):
+        images_train_batch = [Image.fromarray(np.array(image)) for image in images_train[bs:bs+batch_size]]
+        images_train_batch = model.preprocess_image(images_train_batch)
+        images_train_prep.extend(images_train_batch['pixel_values'])
+    for bs in tqdm(range(0, len(images_test), batch_size), desc="Preprocessing test images"):
+        images_test_batch = [Image.fromarray(np.array(image)) for image in images_test[bs:bs+batch_size]]
+        images_test_batch = model.preprocess_image(images_test_batch)
+        images_test_prep.extend(images_test_batch['pixel_values'])
+
+    # Create BatchFeatures
+    images_train_prep = {"pixel_values": images_train_prep}
+    train_batch_features = BatchFeature(data=images_train_prep)
+    images_test_prep = {"pixel_values": images_test_prep}
+    test_batch_features = BatchFeature(data=images_test_prep)
+
+    # Create the datasets
+    train_dataset = RetailDataset(train_batch_features, labels_train, train_transform)
+    test_dataset = RetailDataset(test_batch_features, labels_test, val_transform)
+    logger.info("Train dataset: %d images", len(labels_train))
+    logger.info("Test dataset: %d images", len(labels_test))
+    return train_dataset, test_dataset
+
+def re_training(images, labels, _model, save_model_path='new_model', num_epochs=10):
+    global model
+    model = _model
+    labels = model.label_encoder.transform(labels)
+    normalize = Normalize(mean=model.feature_extractor.image_mean, std=model.feature_extractor.image_std)
+    def train_transforms(batch):
+        return Compose([
+            RandomResizedCrop(model.feature_extractor.size),
+            RandomHorizontalFlip(),
+            ToTensor(),
+            normalize,
+        ])(batch)
+
+    def val_transforms(batch):
+        return Compose([
+            Resize(model.feature_extractor.size),
+            CenterCrop(model.feature_extractor.size),
+            ToTensor(),
+            normalize,
+        ])(batch)
+    train_dataset, test_dataset = prepare_dataset(
+        images, labels, model, .2, train_transforms, val_transforms)
+    trainer = Trainer(
+        model=model,
+        args=TrainingArguments(
+            output_dir='output',
+            overwrite_output_dir=True,
+            num_train_epochs=num_epochs,
+            per_device_train_batch_size=32,
+            gradient_accumulation_steps=1,
+            learning_rate=0.000001,
+            weight_decay=0.01,
+            evaluation_strategy='steps',
+            eval_steps=1000,
+            save_steps=3000),
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset
+    )
+    trainer.train()
+    model.save(save_model_path)