ugmSorcero commited on
Commit
158f4dc
0 Parent(s):

Adds files from huggingface hub repo

Browse files
Files changed (8) hide show
  1. .gitattributes +31 -0
  2. .gitignore +5 -0
  3. README.md +12 -0
  4. app.py +88 -0
  5. dataset.py +27 -0
  6. model.py +191 -0
  7. requirements.txt +11 -0
  8. train.py +110 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ feedback*
2
+ new_model/
3
+ __pycache__/
4
+ data/
5
+ events.out.*
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Grocery Classifier Demo
3
+ emoji: 🛒
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.10.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from PIL import Image
4
+ import requests
5
+ import io
6
+ import time
7
+ from model import ViTForImageClassification
8
+
9
+ st.set_page_config(
10
+ page_title="Grocery Classifier",
11
+ page_icon="interface/shopping-cart.png",
12
+ initial_sidebar_state="expanded"
13
+ )
14
+
15
+ @st.cache()
16
+ def load_model():
17
+ with st.spinner("Loading model"):
18
+ model = ViTForImageClassification('google/vit-base-patch16-224')
19
+ model.load('model/')
20
+ return model
21
+
22
+ model = load_model()
23
+ feedback_path = "feedback"
24
+
25
+ def predict(image):
26
+ print("Predicting...")
27
+ # Load using PIL
28
+ image = Image.open(image)
29
+
30
+ prediction, confidence = model.predict(image)
31
+
32
+ return {'prediction': prediction[0], 'confidence': round(confidence[0], 3)}, image
33
+
34
+ def submit_feedback(correct_label, image):
35
+ folder_path = feedback_path + "/" + correct_label + "/"
36
+ os.makedirs(folder_path, exist_ok=True)
37
+ image.save(folder_path + correct_label + "_" + str(int(time.time())) + ".png")
38
+
39
+ def retrain_from_feedback():
40
+ model.retrain_from_path(feedback_path, remove_path=True)
41
+
42
+ def main():
43
+ labels = set(list(model.label_encoder.classes_))
44
+
45
+ st.title("🍇 Grocery Classifier 🥑")
46
+
47
+ if labels is None:
48
+ st.warning("Received error from server, labels could not be retrieved")
49
+ else:
50
+ st.write("Labels:", labels)
51
+
52
+ image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
53
+ if image_file is not None:
54
+ st.image(image_file)
55
+
56
+ st.subheader("Classification")
57
+
58
+ if st.button("Predict"):
59
+ st.session_state['response_json'], st.session_state['image'] = predict(image_file)
60
+
61
+ if 'response_json' in st.session_state and st.session_state['response_json'] is not None:
62
+ # Show the result
63
+ st.markdown(f"**Prediction:** {st.session_state['response_json']['prediction']}")
64
+ st.markdown(f"**Confidence:** {st.session_state['response_json']['confidence']}")
65
+
66
+ # User feedback
67
+ st.subheader("User Feedback")
68
+ st.markdown("If this prediction was incorrect, please select below the correct label")
69
+ correct_labels = labels.copy()
70
+ correct_labels.remove(st.session_state['response_json']["prediction"])
71
+ correct_label = st.selectbox("Correct label", correct_labels)
72
+ if st.button("Submit"):
73
+ # Save feedback
74
+ try:
75
+ submit_feedback(correct_label, st.session_state['image'])
76
+ st.success("Feedback submitted")
77
+ except Exception as e:
78
+ st.error("Feedback could not be submitted. Error: {}".format(e))
79
+
80
+ # Retrain from feedback
81
+ if st.button("Retrain from feedback"):
82
+ try:
83
+ retrain_from_feedback()
84
+ st.success("Model retrained")
85
+ except Exception as e:
86
+ st.warning("Model could not be retrained. Error: {}".format(e))
87
+
88
+ main()
dataset.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class RetailDataset(torch.utils.data.Dataset):
4
+ def __init__(self, data, labels=None, transform=None):
5
+ self.data = data
6
+ self.labels = labels
7
+ self.num_classes = len(set(labels))
8
+ self.transform = transform
9
+
10
+ def __getitem__(self, idx):
11
+ item = {key: val[idx].detach().clone() for key, val in self.data.items()}
12
+ item['labels'] = self.labels[idx]
13
+ return item
14
+
15
+ def __len__(self):
16
+ return len(self.labels)
17
+
18
+ def __repr__(self):
19
+ return 'RetailDataset'
20
+
21
+ def __str__(self):
22
+ return str({
23
+ 'data': self.data['pixel_values'].shape,
24
+ 'labels': self.labels.shape,
25
+ 'num_classes': self.num_classes,
26
+ 'num_samples': len(self.labels)
27
+ })
model.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import time
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from transformers import ViTModel, ViTFeatureExtractor
6
+ from transformers.modeling_outputs import SequenceClassifierOutput
7
+ import torch.nn as nn
8
+ import torch
9
+ from PIL import Image
10
+ import logging
11
+ import os
12
+ from sklearn.preprocessing import LabelEncoder
13
+ from train import (
14
+ re_training, metric, f1_score,
15
+ classification_report
16
+ )
17
+
18
+ data_path = os.environ.get('DATA_PATH', "./data")
19
+
20
+ logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
21
+ logger = logging.getLogger(__name__)
22
+
23
+ class ViTForImageClassification(nn.Module):
24
+ def __init__(self, model_name, num_labels=24, dropout=0.25, image_size=224):
25
+ logger.info("Loading model")
26
+ super(ViTForImageClassification, self).__init__()
27
+ self.vit = ViTModel.from_pretrained(model_name)
28
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
29
+ self.feature_extractor.do_resize = True
30
+ self.feature_extractor.size = image_size
31
+ self.dropout = nn.Dropout(dropout)
32
+ self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
33
+ self.num_labels = num_labels
34
+ self.label_encoder = LabelEncoder()
35
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ self.model_name = model_name
37
+ # To device
38
+ self.vit.to(self.device)
39
+ self.to(self.device)
40
+ self.classifier.to(self.device)
41
+ logger.info("Model loaded")
42
+
43
+ def forward(self, pixel_values, labels):
44
+ logger.info("Forwarding")
45
+ pixel_values = pixel_values.to(self.device)
46
+ outputs = self.vit(pixel_values=pixel_values)
47
+ output = self.dropout(outputs.last_hidden_state[:,0])
48
+ logits = self.classifier(output)
49
+
50
+ loss = None
51
+ if labels is not None:
52
+ loss_fct = nn.CrossEntropyLoss()
53
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
54
+
55
+ return SequenceClassifierOutput(
56
+ loss=loss,
57
+ logits=logits,
58
+ hidden_states=outputs.hidden_states,
59
+ attentions=outputs.attentions,
60
+ )
61
+
62
+ def preprocess_image(self, images):
63
+ logger.info("Preprocessing images")
64
+ return self.feature_extractor(images, return_tensors='pt')
65
+
66
+ def predict(self, images, batch_size=32, classes_names=True, return_probabilities=False):
67
+ logger.info("Predicting")
68
+ if not isinstance(images, list):
69
+ images = [images]
70
+ classes_list = []
71
+ confidence_list = []
72
+ for bs in tqdm(range(0, len(images), batch_size), desc="Preprocessing training images"):
73
+ images_batch = [image for image in images[bs:bs+batch_size]]
74
+ images_batch = self.preprocess_image(images_batch)['pixel_values']
75
+ sequence_classifier_output = self.forward(images_batch, None)
76
+ # Get max prob
77
+ probs = sequence_classifier_output.logits.softmax(dim=-1).tolist()
78
+ classes = np.argmax(probs, axis=1)
79
+ confidences = np.max(probs, axis=1)
80
+ classes_list.extend(classes)
81
+ confidence_list.extend(confidences)
82
+ if classes_names:
83
+ classes_list = self.label_encoder.inverse_transform(classes_list)
84
+ if return_probabilities:
85
+ return classes_list, confidence_list, probs
86
+ return classes_list, confidence_list
87
+
88
+ def save(self, path):
89
+ logger.info("Saving model")
90
+ os.makedirs(path, exist_ok=True)
91
+ torch.save(self.state_dict(), path + "/model.pt")
92
+ # Save label encoder
93
+ np.save(path + "/label_encoder.npy", self.label_encoder.classes_)
94
+
95
+ def load(self, path):
96
+ logger.info("Loading model")
97
+ # Load label encoder
98
+ # Check if label encoder and model exists
99
+ if not os.path.exists(path + "/label_encoder.npy") or not os.path.exists(path + "/model.pt"):
100
+ logger.warning("Label encoder or model not found")
101
+ return
102
+ self.label_encoder.classes_ = np.load(path + "/label_encoder.npy")
103
+ # Reload classifier layer
104
+ self.classifier = nn.Linear(self.vit.config.hidden_size, len(self.label_encoder.classes_))
105
+
106
+ self.load_state_dict(torch.load(path + "/model.pt", map_location=self.device))
107
+ self.vit.to(self.device)
108
+ self.vit.eval()
109
+ self.to(self.device)
110
+ self.eval()
111
+
112
+ def evaluate(self, images, labels):
113
+ logger.info("Evaluating")
114
+ labels = self.label_encoder.transform(labels)
115
+ # Predict
116
+ y_pred, _ = self.predict(images, classes_names=False)
117
+ # Evaluate
118
+ metrics = metric.compute(predictions=y_pred, references=labels)
119
+ f1 = f1_score.compute(predictions=y_pred, references=labels, average="macro")
120
+ print(classification_report(labels, y_pred, labels=[i for i in range(len(self.label_encoder.classes_))], target_names=self.label_encoder.classes_))
121
+ print(f"Accuracy: {metrics['accuracy']}")
122
+ print(f"F1: {f1}")
123
+
124
+ def partial_fit(self, images, labels, save_model_path='new_model', num_epochs=10):
125
+ logger.info("Partial fitting")
126
+ # Freeze ViT model but last layer
127
+ # params = [param for param in self.vit.parameters()]
128
+ # for param in params[:-1]:
129
+ # param.requires_grad = False
130
+ # Model in training mode
131
+ self.vit.train()
132
+ self.train()
133
+ re_training(images, labels, self, save_model_path, num_epochs)
134
+ self.load(save_model_path)
135
+ self.vit.eval()
136
+ self.eval()
137
+ self.evaluate(images, labels)
138
+
139
+ def __load_from_path(self, path, num_per_label=None):
140
+ images = []
141
+ labels = []
142
+ for label in os.listdir(path):
143
+ count = 0
144
+ label_folder_path = os.path.join(path, label)
145
+ for image_file in tqdm(os.listdir(label_folder_path), desc="Resizing images for label {}".format(label)):
146
+ file_path = os.path.join(label_folder_path, image_file)
147
+ try:
148
+ image = Image.open(file_path)
149
+ image_shape = (self.feature_extractor.size, self.feature_extractor.size)
150
+ if image.size != image_shape:
151
+ image = image.resize(image_shape)
152
+ images.append(image.convert('RGB'))
153
+ labels.append(label)
154
+ count += 1
155
+ except Exception as e:
156
+ print(f"ERROR - Could not resize image {file_path} - {e}")
157
+ if num_per_label is not None and count >= num_per_label:
158
+ break
159
+ return images, labels
160
+
161
+ def retrain_from_path(self,
162
+ path='./data/feedback',
163
+ num_per_label=None,
164
+ save_model_path='new_model',
165
+ remove_path=False,
166
+ num_epochs=10,
167
+ save_new_data=data_path + '/new_data'):
168
+ logger.info("Retraining from path")
169
+ # Load path
170
+ images, labels = self.__load_from_path(path, num_per_label)
171
+ # Retrain
172
+ self.partial_fit(images, labels, save_model_path, num_epochs)
173
+ # Save new data
174
+ if save_new_data is not None:
175
+ logger.info("Saving new data")
176
+ for i ,(image, label) in enumerate(zip(images, labels)):
177
+ label_path = os.path.join(save_new_data, label)
178
+ os.makedirs(label_path, exist_ok=True)
179
+ image.save(os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg"))
180
+ # Remove path folder
181
+ if remove_path:
182
+ logger.info("Removing feedback path")
183
+ shutil.rmtree(path)
184
+
185
+ def evaluate_from_path(self, path, num_per_label=None):
186
+ logger.info("Evaluating from path")
187
+ # Load images
188
+ images, labels = self.__load_from_path(path, num_per_label)
189
+ # Evaluate
190
+ self.evaluate(images, labels)
191
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Pillow
2
+ requests
3
+ numpy
4
+ transformers
5
+ scikit-learn
6
+ datasets
7
+ streamlit
8
+ matplotlib
9
+ scikit-image
10
+ torch
11
+ torchvision
train.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from sklearn.metrics import classification_report
4
+ from tqdm import tqdm
5
+ import logging
6
+ from sklearn.model_selection import train_test_split
7
+ from dataset import RetailDataset
8
+ from PIL import Image
9
+ from datasets import load_metric
10
+ from torchvision.transforms import (
11
+ CenterCrop,
12
+ Compose,
13
+ Normalize,
14
+ RandomHorizontalFlip,
15
+ RandomResizedCrop,
16
+ Resize,
17
+ ToTensor,
18
+ )
19
+ from transformers import Trainer, TrainingArguments, BatchFeature
20
+ metric = load_metric("accuracy")
21
+ f1_score = load_metric("f1")
22
+ np.random.seed(42)
23
+
24
+ logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
25
+ logger = logging.getLogger(__name__)
26
+
27
+ def prepare_dataset(images,
28
+ labels,
29
+ model,
30
+ test_size=.2,
31
+ train_transform=None,
32
+ val_transform=None,
33
+ batch_size=512):
34
+ logger.info("Preparing dataset")
35
+ # Split the dataset in train and test
36
+ try:
37
+ images_train, images_test, labels_train, labels_test = \
38
+ train_test_split(images, labels, test_size=test_size)
39
+ except ValueError:
40
+ logger.warning("Could not split dataset. Using all data for training and testing")
41
+ images_train = images
42
+ labels_train = labels
43
+ images_test = images
44
+ labels_test = labels
45
+
46
+ # Preprocess images using model feature extractor
47
+ images_train_prep = []
48
+ images_test_prep = []
49
+ for bs in tqdm(range(0, len(images_train), batch_size), desc="Preprocessing training images"):
50
+ images_train_batch = [Image.fromarray(np.array(image)) for image in images_train[bs:bs+batch_size]]
51
+ images_train_batch = model.preprocess_image(images_train_batch)
52
+ images_train_prep.extend(images_train_batch['pixel_values'])
53
+ for bs in tqdm(range(0, len(images_test), batch_size), desc="Preprocessing test images"):
54
+ images_test_batch = [Image.fromarray(np.array(image)) for image in images_test[bs:bs+batch_size]]
55
+ images_test_batch = model.preprocess_image(images_test_batch)
56
+ images_test_prep.extend(images_test_batch['pixel_values'])
57
+
58
+ # Create BatchFeatures
59
+ images_train_prep = {"pixel_values": images_train_prep}
60
+ train_batch_features = BatchFeature(data=images_train_prep)
61
+ images_test_prep = {"pixel_values": images_test_prep}
62
+ test_batch_features = BatchFeature(data=images_test_prep)
63
+
64
+ # Create the datasets
65
+ train_dataset = RetailDataset(train_batch_features, labels_train, train_transform)
66
+ test_dataset = RetailDataset(test_batch_features, labels_test, val_transform)
67
+ logger.info("Train dataset: %d images", len(labels_train))
68
+ logger.info("Test dataset: %d images", len(labels_test))
69
+ return train_dataset, test_dataset
70
+
71
+ def re_training(images, labels, _model, save_model_path='new_model', num_epochs=10):
72
+ global model
73
+ model = _model
74
+ labels = model.label_encoder.transform(labels)
75
+ normalize = Normalize(mean=model.feature_extractor.image_mean, std=model.feature_extractor.image_std)
76
+ def train_transforms(batch):
77
+ return Compose([
78
+ RandomResizedCrop(model.feature_extractor.size),
79
+ RandomHorizontalFlip(),
80
+ ToTensor(),
81
+ normalize,
82
+ ])(batch)
83
+
84
+ def val_transforms(batch):
85
+ return Compose([
86
+ Resize(model.feature_extractor.size),
87
+ CenterCrop(model.feature_extractor.size),
88
+ ToTensor(),
89
+ normalize,
90
+ ])(batch)
91
+ train_dataset, test_dataset = prepare_dataset(
92
+ images, labels, model, .2, train_transforms, val_transforms)
93
+ trainer = Trainer(
94
+ model=model,
95
+ args=TrainingArguments(
96
+ output_dir='output',
97
+ overwrite_output_dir=True,
98
+ num_train_epochs=num_epochs,
99
+ per_device_train_batch_size=32,
100
+ gradient_accumulation_steps=1,
101
+ learning_rate=0.000001,
102
+ weight_decay=0.01,
103
+ evaluation_strategy='steps',
104
+ eval_steps=1000,
105
+ save_steps=3000),
106
+ train_dataset=train_dataset,
107
+ eval_dataset=test_dataset
108
+ )
109
+ trainer.train()
110
+ model.save(save_model_path)