danibalcells commited on
Commit
cef3c44
·
1 Parent(s): d80adbc

First working Gradio app

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +110 -0
  2. dataset2-3/crops/Abyssinian_1.jpg +0 -0
  3. dataset2-3/crops/Abyssinian_10.jpg +0 -0
  4. dataset2-3/crops/Abyssinian_100.jpg +0 -0
  5. dataset2-3/crops/Abyssinian_101.jpg +0 -0
  6. dataset2-3/crops/Abyssinian_102.jpg +0 -0
  7. dataset2-3/crops/Abyssinian_103.jpg +0 -0
  8. dataset2-3/crops/Abyssinian_105.jpg +0 -0
  9. dataset2-3/crops/Abyssinian_106.jpg +0 -0
  10. dataset2-3/crops/Abyssinian_107.jpg +0 -0
  11. dataset2-3/crops/Abyssinian_108.jpg +0 -0
  12. dataset2-3/crops/Abyssinian_109.jpg +0 -0
  13. dataset2-3/crops/Abyssinian_11.jpg +0 -0
  14. dataset2-3/crops/Abyssinian_110.jpg +0 -0
  15. dataset2-3/crops/Abyssinian_111.jpg +0 -0
  16. dataset2-3/crops/Abyssinian_112.jpg +0 -0
  17. dataset2-3/crops/Abyssinian_113.jpg +0 -0
  18. dataset2-3/crops/Abyssinian_114.jpg +0 -0
  19. dataset2-3/crops/Abyssinian_115.jpg +0 -0
  20. dataset2-3/crops/Abyssinian_116.jpg +0 -0
  21. dataset2-3/crops/Abyssinian_117.jpg +0 -0
  22. dataset2-3/crops/Abyssinian_118.jpg +0 -0
  23. dataset2-3/crops/Abyssinian_119.jpg +0 -0
  24. dataset2-3/crops/Abyssinian_12.jpg +0 -0
  25. dataset2-3/crops/Abyssinian_120.jpg +0 -0
  26. dataset2-3/crops/Abyssinian_121.jpg +0 -0
  27. dataset2-3/crops/Abyssinian_122.jpg +0 -0
  28. dataset2-3/crops/Abyssinian_123.jpg +0 -0
  29. dataset2-3/crops/Abyssinian_124.jpg +0 -0
  30. dataset2-3/crops/Abyssinian_125.jpg +0 -0
  31. dataset2-3/crops/Abyssinian_126.jpg +0 -0
  32. dataset2-3/crops/Abyssinian_127.jpg +0 -0
  33. dataset2-3/crops/Abyssinian_128.jpg +0 -0
  34. dataset2-3/crops/Abyssinian_129.jpg +0 -0
  35. dataset2-3/crops/Abyssinian_13.jpg +0 -0
  36. dataset2-3/crops/Abyssinian_130.jpg +0 -0
  37. dataset2-3/crops/Abyssinian_131.jpg +0 -0
  38. dataset2-3/crops/Abyssinian_132.jpg +0 -0
  39. dataset2-3/crops/Abyssinian_133.jpg +0 -0
  40. dataset2-3/crops/Abyssinian_134.jpg +0 -0
  41. dataset2-3/crops/Abyssinian_135.jpg +0 -0
  42. dataset2-3/crops/Abyssinian_136.jpg +0 -0
  43. dataset2-3/crops/Abyssinian_137.jpg +0 -0
  44. dataset2-3/crops/Abyssinian_138.jpg +0 -0
  45. dataset2-3/crops/Abyssinian_139.jpg +0 -0
  46. dataset2-3/crops/Abyssinian_14.jpg +0 -0
  47. dataset2-3/crops/Abyssinian_140.jpg +0 -0
  48. dataset2-3/crops/Abyssinian_141.jpg +0 -0
  49. dataset2-3/crops/Abyssinian_142.jpg +0 -0
  50. dataset2-3/crops/Abyssinian_143.jpg +0 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from functools import partial
3
+ from pathlib import Path
4
+ from io import BytesIO
5
+
6
+
7
+ import matplotlib.pyplot as plt
8
+ import torch
9
+ from fastai.vision.all import *
10
+ import gradio as gr
11
+
12
+
13
+ from feature_extractor import FeatureExtractor
14
+
15
+
16
+ MODEL_NAME = Path('model_dataset2-3_smallcrop_tinyresize.pkl')
17
+ FEATURES_NAME = Path('features_dataset2-3_smallcrop_tinyresize.pkl')
18
+
19
+ def get_label(file_path):
20
+ return os.path.basename(file_path).split('_')[0]
21
+
22
+ def loss_func(x, y):
23
+ return torch.tensor(0.)
24
+
25
+ def get_image_features(input_image, feature_extractor):
26
+ with feature_extractor.no_bar(), feature_extractor.no_logging():
27
+ _, features, _ = feature_extractor.predict(input_image)
28
+ return features
29
+
30
+ def get_similar_image(input_image, feature_extractor, features_dict):
31
+ # Convert the features dictionary to a list of tuples
32
+ features_list = list(features_dict.items())
33
+
34
+ # Extract the image paths and features
35
+ image_paths, feature_tensors = zip(*features_list)
36
+
37
+ # Convert the features to a PyTorch tensor
38
+ features_tensor = torch.stack(feature_tensors)
39
+
40
+ # Now, to compute the cosine similarity between the user's input image and all other images:
41
+ user_features = get_image_features(input_image, feature_extractor)
42
+ user_features = user_features.view(1, -1) # Reshape to 2D tensor
43
+
44
+ # Compute cosine similarity
45
+ similarity_scores = torch.nn.functional.cosine_similarity(user_features, features_tensor)
46
+
47
+ # Get the index of the most similar image
48
+ most_similar_index = torch.argmax(similarity_scores)
49
+
50
+ # Get the path of the most similar image
51
+ most_similar_image_path = image_paths[most_similar_index]
52
+ # Display the most similar image and the input image side by side
53
+ most_similar_image = PILImage.create(most_similar_image_path)
54
+ return most_similar_image
55
+
56
+ def plot_side_by_side(input_image, similar_image, show=True, save_path=None):
57
+ similar_image_thumb = similar_image.to_thumb(224)
58
+ user_image_thumb = input_image.to_thumb(224)
59
+ # Create a figure with two subplots
60
+ fig, (ax1, ax2) = plt.subplots(1, 2)
61
+
62
+ # Display the images
63
+ ax1.imshow(similar_image_thumb)
64
+ ax2.imshow(user_image_thumb)
65
+
66
+ # Optionally, remove the axes for a cleaner look
67
+ ax1.axis('off')
68
+ ax2.axis('off')
69
+
70
+ fig.suptitle('Is It Really Worth It?', fontsize=20, weight='bold')
71
+ if save_path:
72
+ plt.savefig(save_path)
73
+ plt.close()
74
+ if show:
75
+ plt.show()
76
+ # Convert the plot to a PIL Image
77
+ buf = BytesIO()
78
+ plt.savefig(buf, format='png')
79
+ plt.close(fig)
80
+ buf.seek(0)
81
+ result_image = Image.open(buf)
82
+
83
+ return result_image
84
+
85
+ def process_image(input_image, feature_extractor, features_dict, show=True, save_path=None):
86
+ similar_image = get_similar_image(input_image, feature_extractor, features_dict)
87
+ meme = plot_side_by_side(input_image, similar_image, show=show, save_path=save_path)
88
+ return meme
89
+
90
+ def load_model_and_features():
91
+ # Load the model
92
+ feature_extractor = load_learner(MODEL_NAME)
93
+
94
+ with open(FEATURES_NAME, 'rb') as f:
95
+ features_dict = pickle.load(f)
96
+
97
+ return feature_extractor, features_dict
98
+
99
+ def predict(input_image):
100
+ img = PILImage.create(input_image)
101
+ feature_extractor, features_dict = load_model_and_features()
102
+ return process_image(img, feature_extractor, features_dict, show=False)
103
+
104
+ iface = gr.Interface(
105
+ fn=predict,
106
+ inputs='image',
107
+ outputs='image',
108
+ )
109
+
110
+ iface.launch()
dataset2-3/crops/Abyssinian_1.jpg ADDED
dataset2-3/crops/Abyssinian_10.jpg ADDED
dataset2-3/crops/Abyssinian_100.jpg ADDED
dataset2-3/crops/Abyssinian_101.jpg ADDED
dataset2-3/crops/Abyssinian_102.jpg ADDED
dataset2-3/crops/Abyssinian_103.jpg ADDED
dataset2-3/crops/Abyssinian_105.jpg ADDED
dataset2-3/crops/Abyssinian_106.jpg ADDED
dataset2-3/crops/Abyssinian_107.jpg ADDED
dataset2-3/crops/Abyssinian_108.jpg ADDED
dataset2-3/crops/Abyssinian_109.jpg ADDED
dataset2-3/crops/Abyssinian_11.jpg ADDED
dataset2-3/crops/Abyssinian_110.jpg ADDED
dataset2-3/crops/Abyssinian_111.jpg ADDED
dataset2-3/crops/Abyssinian_112.jpg ADDED
dataset2-3/crops/Abyssinian_113.jpg ADDED
dataset2-3/crops/Abyssinian_114.jpg ADDED
dataset2-3/crops/Abyssinian_115.jpg ADDED
dataset2-3/crops/Abyssinian_116.jpg ADDED
dataset2-3/crops/Abyssinian_117.jpg ADDED
dataset2-3/crops/Abyssinian_118.jpg ADDED
dataset2-3/crops/Abyssinian_119.jpg ADDED
dataset2-3/crops/Abyssinian_12.jpg ADDED
dataset2-3/crops/Abyssinian_120.jpg ADDED
dataset2-3/crops/Abyssinian_121.jpg ADDED
dataset2-3/crops/Abyssinian_122.jpg ADDED
dataset2-3/crops/Abyssinian_123.jpg ADDED
dataset2-3/crops/Abyssinian_124.jpg ADDED
dataset2-3/crops/Abyssinian_125.jpg ADDED
dataset2-3/crops/Abyssinian_126.jpg ADDED
dataset2-3/crops/Abyssinian_127.jpg ADDED
dataset2-3/crops/Abyssinian_128.jpg ADDED
dataset2-3/crops/Abyssinian_129.jpg ADDED
dataset2-3/crops/Abyssinian_13.jpg ADDED
dataset2-3/crops/Abyssinian_130.jpg ADDED
dataset2-3/crops/Abyssinian_131.jpg ADDED
dataset2-3/crops/Abyssinian_132.jpg ADDED
dataset2-3/crops/Abyssinian_133.jpg ADDED
dataset2-3/crops/Abyssinian_134.jpg ADDED
dataset2-3/crops/Abyssinian_135.jpg ADDED
dataset2-3/crops/Abyssinian_136.jpg ADDED
dataset2-3/crops/Abyssinian_137.jpg ADDED
dataset2-3/crops/Abyssinian_138.jpg ADDED
dataset2-3/crops/Abyssinian_139.jpg ADDED
dataset2-3/crops/Abyssinian_14.jpg ADDED
dataset2-3/crops/Abyssinian_140.jpg ADDED
dataset2-3/crops/Abyssinian_141.jpg ADDED
dataset2-3/crops/Abyssinian_142.jpg ADDED
dataset2-3/crops/Abyssinian_143.jpg ADDED