ugmSorcero
Adds files from huggingface hub repo
158f4dc
import shutil
import time
import numpy as np
from tqdm import tqdm
from transformers import ViTModel, ViTFeatureExtractor
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn
import torch
from PIL import Image
import logging
import os
from sklearn.preprocessing import LabelEncoder
from train import (
re_training, metric, f1_score,
classification_report
)
# Root folder for datasets; override it with the DATA_PATH environment variable.
data_path = os.getenv("DATA_PATH", "./data")

# Logging verbosity is read from LOGGER_LEVEL and falls back to WARNING.
logging.basicConfig(
    level=os.environ.get("LOGGER_LEVEL", logging.WARNING),
)
logger = logging.getLogger(__name__)
class ViTForImageClassification(nn.Module):
    """ViT backbone followed by dropout and a linear classification head.

    The instance bundles a Hugging Face ``ViTModel``, its feature extractor
    and a scikit-learn ``LabelEncoder`` so that images can be preprocessed,
    classified, and the predicted indices mapped back to label names.
    """

    def __init__(self, model_name, num_labels=24, dropout=0.25, image_size=224):
        """Load the pretrained backbone and build the classification head.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                ViT model and its feature extractor.
            num_labels: Output size of the linear classifier.
            dropout: Dropout probability applied before the classifier.
            image_size: Value assigned to the feature extractor's ``size``.
        """
        logger.info("Loading model")
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.feature_extractor.do_resize = True
        # NOTE(review): __load_from_path reads this back as a single integer
        # edge length -- confirm the installed transformers version keeps it so.
        self.feature_extractor.size = image_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels
        self.label_encoder = LabelEncoder()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        # Moving the parent module also moves every registered submodule
        # (backbone, dropout, classifier), so one call is sufficient.
        self.to(self.device)
        logger.info("Model loaded")

    def forward(self, pixel_values, labels=None):
        """Run the backbone and the classification head.

        Args:
            pixel_values: Tensor of preprocessed images; moved to
                ``self.device`` before the backbone is called.
            labels: Optional tensor of target class indices. When given, a
                cross-entropy loss is computed.

        Returns:
            A ``SequenceClassifierOutput`` with ``loss`` (``None`` when no
            labels are supplied), ``logits``, ``hidden_states`` and
            ``attentions``.
        """
        logger.info("Forwarding")
        pixel_values = pixel_values.to(self.device)
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the first token of the last hidden state.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        loss = None
        if labels is not None:
            # Labels may arrive on another device than the logits; the loss
            # needs both on the same one.
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            # Use the real width of the logits rather than self.num_labels so
            # the reshape stays valid after the classifier has been replaced.
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def preprocess_image(self, images):
        """Run the feature extractor on ``images`` and return PyTorch tensors."""
        logger.info("Preprocessing images")
        return self.feature_extractor(images, return_tensors='pt')

    def predict(self, images, batch_size=32, classes_names=True, return_probabilities=False):
        """Classify one image or a list of images.

        Inference runs in evaluation mode and without gradient tracking; the
        training/eval mode of every submodule is restored afterwards.

        Args:
            images: A single image or a list of images accepted by the
                feature extractor.
            batch_size: Number of images pushed through the model at once.
            classes_names: When true, predicted indices are mapped back to
                label names through the label encoder.
            return_probabilities: When true, the per-class probabilities of
                every image are returned as a third element.

        Returns:
            ``(classes, confidences)`` or, with ``return_probabilities``,
            ``(classes, confidences, probabilities)`` where ``probabilities``
            holds one list of class probabilities per input image.
        """
        logger.info("Predicting")
        if not isinstance(images, list):
            images = [images]
        classes_list = []
        confidence_list = []
        probs_list = []
        # Remember each module's mode so the caller's state (for example a
        # training loop in progress) is left exactly as it was found.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                for bs in tqdm(range(0, len(images), batch_size), desc="Preprocessing training images"):
                    images_batch = images[bs:bs + batch_size]
                    pixel_values = self.preprocess_image(images_batch)['pixel_values']
                    sequence_classifier_output = self.forward(pixel_values, None)
                    probs = sequence_classifier_output.logits.softmax(dim=-1).tolist()
                    classes_list.extend(np.argmax(probs, axis=1))
                    confidence_list.extend(np.max(probs, axis=1))
                    # Keep every batch, not just the last one.
                    probs_list.extend(probs)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        if classes_names:
            classes_list = self.label_encoder.inverse_transform(classes_list)
        if return_probabilities:
            return classes_list, confidence_list, probs_list
        return classes_list, confidence_list

    def save(self, path):
        """Write the model weights and the label encoder classes into ``path``.

        The directory is created when missing. Weights go to ``model.pt`` and
        the encoder classes to ``label_encoder.npy``.
        """
        logger.info("Saving model")
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pt"))
        np.save(os.path.join(path, "label_encoder.npy"), self.label_encoder.classes_)

    def load(self, path):
        """Restore weights and label encoder classes previously written by ``save``.

        When either file is missing a warning is logged and the model is left
        untouched. On success the classifier is rebuilt to match the number of
        stored classes and the model is switched to evaluation mode.
        """
        logger.info("Loading model")
        encoder_file = os.path.join(path, "label_encoder.npy")
        weights_file = os.path.join(path, "model.pt")
        if not (os.path.exists(encoder_file) and os.path.exists(weights_file)):
            logger.warning("Label encoder or model not found")
            return
        self.label_encoder.classes_ = np.load(encoder_file)
        # Keep num_labels in sync with the rebuilt classifier head.
        self.num_labels = len(self.label_encoder.classes_)
        self.classifier = nn.Linear(self.vit.config.hidden_size, self.num_labels)
        self.load_state_dict(torch.load(weights_file, map_location=self.device))
        # The freshly created classifier starts on the CPU; move everything.
        self.to(self.device)
        self.eval()

    def evaluate(self, images, labels):
        """Predict ``images`` and print a classification report, accuracy and F1.

        Args:
            images: Images to classify.
            labels: Ground-truth label names; encoded with the label encoder
                before being compared with the predictions.
        """
        logger.info("Evaluating")
        labels = self.label_encoder.transform(labels)
        y_pred, _ = self.predict(images, classes_names=False)
        metrics = metric.compute(predictions=y_pred, references=labels)
        f1 = f1_score.compute(predictions=y_pred, references=labels, average="macro")
        class_names = self.label_encoder.classes_
        print(classification_report(labels, y_pred, labels=list(range(len(class_names))), target_names=class_names))
        print(f"Accuracy: {metrics['accuracy']}")
        print(f"F1: {f1}")

    def partial_fit(self, images, labels, save_model_path='new_model', num_epochs=10):
        """Retrain on ``images``/``labels``, reload the saved model and evaluate it.

        Args:
            images: Training images.
            labels: Label names matching ``images``.
            save_model_path: Directory handed to ``re_training`` and then read
                back with ``load``.
            num_epochs: Number of epochs handed to ``re_training``.
        """
        logger.info("Partial fitting")
        # This method does not freeze any parameters itself; train() on the
        # parent switches the backbone and the head to training mode together.
        self.train()
        re_training(images, labels, self, save_model_path, num_epochs)
        self.load(save_model_path)
        # load() returns early when nothing was saved, so set eval explicitly.
        self.eval()
        self.evaluate(images, labels)

    def __load_from_path(self, path, num_per_label=None):
        """Load images from ``path``, using each sub-directory name as the label.

        Entries of ``path`` that are not directories are skipped. Directory
        listings are sorted so the selection is reproducible when
        ``num_per_label`` limits how many images are taken. Files that cannot
        be read are reported and skipped.

        Args:
            path: Directory containing one sub-directory per label.
            num_per_label: Optional cap on the images loaded per label.

        Returns:
            ``(images, labels)``: RGB images resized to the feature
            extractor's size, and the label name of each image.
        """
        images = []
        labels = []
        target_size = (self.feature_extractor.size, self.feature_extractor.size)
        for label in sorted(os.listdir(path)):
            label_folder_path = os.path.join(path, label)
            if not os.path.isdir(label_folder_path):
                # Stray files next to the label folders are not labels.
                continue
            count = 0
            for image_file in tqdm(sorted(os.listdir(label_folder_path)), desc="Resizing images for label {}".format(label)):
                file_path = os.path.join(label_folder_path, image_file)
                try:
                    # The context manager closes the underlying file; resize()
                    # and convert() return independent in-memory images.
                    with Image.open(file_path) as source:
                        prepared = source
                        if prepared.size != target_size:
                            prepared = prepared.resize(target_size)
                        prepared = prepared.convert('RGB')
                    images.append(prepared)
                    labels.append(label)
                    count += 1
                except Exception as e:
                    # Best effort: one unreadable file must not abort the load.
                    print(f"ERROR - Could not resize image {file_path} - {e}")
                if num_per_label is not None and count >= num_per_label:
                    break
        return images, labels

    def retrain_from_path(self,
                          path='./data/feedback',
                          num_per_label=None,
                          save_model_path='new_model',
                          remove_path=False,
                          num_epochs=10,
                          save_new_data=data_path + '/new_data'):
        """Retrain the model on the labelled images stored under ``path``.

        When no image can be loaded a warning is logged and nothing else
        happens (no training, no copy, no removal).

        Args:
            path: Directory with one sub-directory per label.
            num_per_label: Optional cap on the images loaded per label.
            save_model_path: Directory where the retrained model is stored.
            remove_path: When true, ``path`` is deleted after retraining.
            num_epochs: Number of training epochs.
            save_new_data: Directory the training images are copied into,
                grouped by label; ``None`` disables the copy.
        """
        logger.info("Retraining from path")
        images, labels = self.__load_from_path(path, num_per_label)
        if not images:
            logger.warning("No images found in %s, skipping retraining", path)
            return
        self.partial_fit(images, labels, save_model_path, num_epochs)
        if save_new_data is not None:
            logger.info("Saving new data")
            for i, (image, label) in enumerate(zip(images, labels)):
                label_path = os.path.join(save_new_data, label)
                os.makedirs(label_path, exist_ok=True)
                # The running index keeps names unique within one timestamp.
                image.save(os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg"))
        if remove_path:
            logger.info("Removing feedback path")
            shutil.rmtree(path)

    def evaluate_from_path(self, path, num_per_label=None):
        """Load the labelled images under ``path`` and evaluate the model on them.

        Args:
            path: Directory with one sub-directory per label.
            num_per_label: Optional cap on the images loaded per label.
        """
        logger.info("Evaluating from path")
        images, labels = self.__load_from_path(path, num_per_label)
        self.evaluate(images, labels)