import gradio as gr
import torch
import os
import kornia.filters
import torchvision.transforms.functional
import requests
from PIL import Image
from torchvision import transforms
from operator import itemgetter
import pickle
import io
from skimage.transform import resize
# NOTE: the star imports below keep their original order because later ones may
# shadow names from earlier ones. They are assumed to supply `cv2`, `plt`,
# `models`, `nn` and `Resnet_with_skip` used in this file — TODO(review):
# replace with explicit imports once confirmed.
from utils_functions.imports import *
from util_models.resnet_with_skip import *
from util_models.densenet_with_skip import *
from util_models.glyphnet_with_skip import *


def _item_uuid_from_path(image_path):
    """Return the item UUID embedded in a dataset file name.

    File names are parsed as ``...itemUUID_<uuid>_photoUUID...``. If the
    ``itemUUID_`` marker is absent, the bare file name is returned instead of
    raising.
    """
    file_name = os.path.basename(image_path)
    parts = file_name.split("_photoUUID")[0].split("itemUUID_")
    return parts[1] if len(parts) > 1 else file_name


def create_retrieval_figure(res):
    """Draw the top retrieved images and list their item UUIDs.

    Args:
        res: Iterable of image paths ordered from most to least similar
            (``predict`` passes a dict keyed by path). The first entry is
            skipped — NOTE(review): presumably because it is expected to be
            the query image itself; confirm.

    Returns:
        A ``(fig, names)`` tuple. ``fig`` is a matplotlib figure with up to 10
        thumbnails in a 2 x 5 grid. ``names`` has one
        ``"Top i item UUID is ...\\n"`` line per shown image.
    """
    rows, cols = 2, 5
    # rcParams are read when a figure is created, so set this first.
    plt.rcParams['figure.facecolor'] = 'white'
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    ax_query = fig.add_subplot(rows, 1, 1)
    ax_query.axis('off')
    ax_query.set_title('Top 10 most similar scarabs', fontsize=40)

    name_lines = []
    for rank, image_path in enumerate(res):
        if rank == 0:
            continue
        if rank > rows * cols:
            break
        ax = fig.add_subplot(rows, cols, rank)
        ax.axis('off')
        ax.set_title('Top {}'.format(rank), fontsize=40)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None for missing/unreadable files; leave the
            # slot empty instead of failing the whole request.
            name_lines.append("Top {} image could not be read: {}\n".format(rank, image_path))
            continue
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        item_uuid = _item_uuid_from_path(image_path)
        name_lines.append("Top " + str(rank) + " item UUID is " + item_uuid + "\n")
    return fig, "".join(name_lines)


def knn_calc(image_name, query_feature, features):
    """Return the distance between a query embedding and one indexed image.

    The distance is the negated cosine similarity along dim 1, averaged over
    the resulting tensor. Smaller values therefore mean "more similar", and
    results can be sorted ascending.

    Args:
        image_name: Key of the indexed image in ``features``.
        query_feature: Embedding tensor of the query image.
        features: Mapping from image path to stored embedding tensor.

    Returns:
        The distance as a Python float.
    """
    # `device` is the module-level inference device.
    candidate = features[image_name].to(device)
    similarity = torch.nn.functional.cosine_similarity(query_feature, candidate, dim=1)
    return -similarity.mean().item()


def return_all_features(model_test, query_images_paths, glyph = False):
    """Embed every image in ``query_images_paths`` with ``model_test``.

    Processing steps:
      1. Read each image with OpenCV.
      2. Convert it from BGR to RGB, so the grayscale conversion uses the same
         channel weights as the PIL/RGB path in ``predict``.
      3. Resize to 224x224.
      4. Convert to 3-channel grayscale.
      5. Normalise to [-1, 1].

    Args:
        model_test: Embedding network. It is put in eval mode and moved to the
            GPU when one is available.
        query_images_paths: Paths of the images to embed.
        glyph: If True, collapse the normalised input to one grayscale channel
            before the forward pass.

    Returns:
        Dict mapping each readable image path to its embedding. Embeddings are
        stored on the CPU, so the dict can be pickled and loaded on a machine
        without CUDA. Unreadable paths are skipped with a printed message.

    NOTE: an index built by the previous version used swapped R/B grayscale
    weights (BGR input); rebuild ``features.pkl`` for consistent retrieval.
    """
    model_test.eval()
    compute_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_test.to(compute_device)
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    gray_scale = transforms.Grayscale(num_output_channels=1)

    features = {}
    with torch.no_grad():
        for count, image_path in enumerate(query_images_paths, start=1):
            bgr_image = cv2.imread(image_path)
            if bgr_image is None:
                # e.g. non-image files picked up by os.listdir.
                print("Skipping unreadable image: {}".format(image_path))
                continue
            rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            img = transform(rgb_image).unsqueeze(0).contiguous().to(compute_device)
            if glyph:
                img = gray_scale(img)
            features[image_path] = model_test(img).cpu()
            if count % 100 == 0:
                print("Processed {} images".format(count))
    # Release cached GPU blocks once at the end instead of after every image.
    torch.cuda.empty_cache()
    return features


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Multi-label shape classifier wrapped with a reconstruction branch.
# NOTE(review): pretrained=True downloads ImageNet weights that the strict
# load_state_dict below fully overwrites; the download is redundant.
checkpoint_path = "multi_label.pth.tar"
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)  # presumably one output per shape label — confirm
model = Resnet_with_skip(resnet).to(device)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)

# Every child of the wrapper except the last one is used as the retrieval
# embedding network. NOTE(review): relies on Resnet_with_skip's child order.
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
embedding_model_test.to(device)

# Period classifier: same backbone, 5 outputs.
periods_model = models.resnet101(pretrained=True)
periods_model.fc = nn.Linear(num_ftrs, 5)
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)
periods_model.to(device)

# Dataset images that can be retrieved.
data_dir = "../cssl_dataset/all_image_base/1/"
query_images_paths = [os.path.join(data_dir, path) for path in os.listdir(data_dir)]
# Retrieval index: dataset image path -> embedding.
# To build it, run `return_all_features(embedding_model_test, query_images_paths)`
# and save the result to 'features.pkl' with `pickle.dump`. Rebuild it
# whenever the dataset or the embedding model changes.
# NOTE: pickle.load can execute arbitrary code — only load a file you built.
with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)
# Move all embeddings to the inference device once at startup. That way
# knn_calc's `.to(device)` is a no-op instead of a copy per image per query.
features = {path: embedding.to(device) for path, embedding in features.items()}

model.eval()
periods_model.eval()

# Preprocessing for uploaded PIL images. It uses the same resize, 3-channel
# grayscale and [-1, 1] normalisation as return_all_features.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# Inverse of Normalize(0.5, 0.5): maps x to x * 0.5 + 0.5.
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])

# Shape label names, sorted by directory name. NOTE(review): assumed to be in
# the same order as the classifier's outputs — confirm.
labels = sorted(os.listdir("../cssl_dataset/shape_multi_label/photos"))
periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']

# Sigmoid probability at or above which a shape label is reported.
MULTI_LABEL_THRESHOLD = 0.8


def predict(inp):
    """Classify, reconstruct and retrieve similar scarabs for one uploaded image.

    Args:
        inp: The uploaded image as a PIL image (``gr.Image(type="pil")``).

    Returns:
        A tuple ``(fig, names, plot_recon, confidences, periods_confidences)``,
        matching the Gradio outputs:
          - ``fig`` and ``names``: returned by ``create_retrieval_figure``.
          - ``plot_recon``: the reconstruction resized to the input's size.
          - ``confidences``: maps the ``"&"``-joined predicted shape labels to
            1.0. The key is ``""`` when no label reaches the threshold.
          - ``periods_confidences``: maps each period label to its softmax
            probability, as a Python float.
    """
    batch = transform(inp).to(device).unsqueeze(0)
    with torch.no_grad():
        classification, reconstruction = model(batch)
        periods_classification = periods_model(batch)
        # Retrieval embedding; computed under no_grad so no autograd graph is kept.
        query_feature = embedding_model_test(batch)

        # NOTE(review): assumes reconstruction is (1, 1, H, W), so that
        # [0].repeat(3, 1, 1) yields a (3, H, W) image — confirm.
        recon_tensor = reconstruction[0].repeat(3, 1, 1)
        recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
        plot_recon = recon_tensor.cpu().permute(1, 2, 0).numpy()

        shape_flags = (torch.sigmoid(classification[0]) >= MULTI_LABEL_THRESHOLD).tolist()
        periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0).tolist()

        dists = {
            image_name: knn_calc(image_name, query_feature, features)
            for image_name in query_images_paths
            # Files added to data_dir after the index was built have no embedding.
            if image_name in features
        }

    w, h = inp.size
    plot_recon = resize(plot_recon, (h, w))

    predicted_labels = [label for label, is_on in zip(labels, shape_flags) if is_on]
    # gr.Label expects plain float confidences, not (possibly CUDA) tensors.
    confidences = {"&".join(predicted_labels): 1.0}
    periods_confidences = {label: float(p) for label, p in zip(periods_labels, periods_prediction)}

    # Ascending distance == most similar first; sorted() is stable for ties.
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    # NOTE(review): the figure is never closed, so pyplot keeps one figure per
    # request alive; consider closing it once Gradio has rendered it.
    fig, names = create_retrieval_figure(res)
    return fig, names, plot_recon, confidences, periods_confidences


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=['plot', 'text', "image", gr.Label(num_top_classes=1), gr.Label(num_top_classes=1)],
)
demo.launch(share=True)