CHSTR committed
Commit 09892bf · 1 Parent(s): 8d750fc

Use the dataset from Hugging Face

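The key change (see the app.py diff below) is that the image corpus now comes from the CHSTR/ecommerce dataset on the Hub, and the local directory holding the downloaded images is recovered from the first decoded image's filename. A minimal sketch of that lookup, assuming the validation split yields PIL images whose `.filename` points into the local `datasets` cache:

```python
# Minimal sketch of the path lookup used in app.py below. Assumes the
# validation split decodes to PIL images whose .filename attribute points
# into the local `datasets` cache.
from datasets import load_dataset

dataset = load_dataset("CHSTR/ecommerce")
first_image = dataset["validation"]["image"][0]  # a PIL.Image
path_images = "/".join(first_image.filename.split("/")[:-3]) + "/"
print(path_images)  # base directory the embedding paths are rebased onto
```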
__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.75 kB).
 
app.py CHANGED
@@ -1,90 +1,85 @@
 import os
-
 import streamlit as st
 from io import BytesIO
-import base64
 from multiprocessing.dummy import Pool
-from PIL import Image, ImageDraw, ImageOps
-
+import base64
+from PIL import Image, ImageOps
 import torch
 from torchvision import transforms
-
-# sketches
 from streamlit_drawable_canvas import st_canvas
 from src.model_LN_prompt import Model
-
-
-import pickle as pkl
 from html import escape
+import pickle as pkl
 from huggingface_hub import hf_hub_download, login
 from datasets import load_dataset
 
-token = os.getenv("HUGGINGFACE_TOKEN")
 
-# Authenticate using the token
-login(token=token, add_to_git_credential=True)
+if 'initialized' not in st.session_state:
+    st.session_state.initialized = False
 
-# Variables
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print(f"Device: {device}")
 HEIGHT = 200
-N_RESULTS = 15
+N_RESULTS = 20
 color = st.get_option("theme.primaryColor")
 if color is None:
     color = (0, 0, 255)
 else:
     color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))
 
+
 @st.cache_resource
-def load():
-    print("Cargando todo...")
+def initialize_huggingface():
+    token = os.getenv("HUGGINGFACE_TOKEN")
+    if token:
+        login(token=token)
+    else:
+        st.error("HUGGINGFACE_TOKEN not found in environment variables")
+
+
+@st.cache_resource
+def load_model_and_data():
+    print("Loading everything...")
     dataset = load_dataset("CHSTR/ecommerce")
     path_images = "/".join(dataset['validation']
                            ['image'][0].filename.split("/")[:-3]) + "/"
-    print(f"Directorio de imágenes: {path_images}")
 
-    # Download the model from Hugging Face
+    # Download model
     path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
-    print(f"Archivo del modelo descargado en: {path_model}")
 
-    # Load the model
-    model = Model()
+    # Load model
+    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()
-    # model.to(device)
-    print("Modelo cargado exitosamente")
 
-    # Download and load the embeddings from Hugging Face
+    # Download and load embeddings
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
-    print(f"Archivo de embeddings descargado en: {embeddings_file}")
 
    embeddings = {
        0: pkl.load(open(embeddings_file, "rb")),
        1: pkl.load(open(embeddings_file, "rb"))
    }
 
-    # Update the image paths in the embeddings
-    for i in range(len(embeddings[0])):
-        embeddings[0][i] = (embeddings[0][i][0], path_images +
-                            "/".join(embeddings[0][i][1].split("/")[-3:]))
-        # print(embeddings[0][i])
-
-    for i in range(len(embeddings[1])):
-        embeddings[1][i] = (embeddings[1][i][0], path_images +
-                            "/".join(embeddings[1][i][1].split("/")[-3:]))
+    # Update image paths
+    for corpus_id in [0, 1]:
+        embeddings[corpus_id] = [
+            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
+            for emb in embeddings[corpus_id]
+        ]
 
    return model, path_images, embeddings
 
-def compute_sketch(sketch):
+
+def compute_sketch(_sketch, model):
    with torch.no_grad():
-        sketch_feat = model(sketch.to(device), dtype='sketch')
+        sketch_feat = model(_sketch.to(device), dtype='sketch')
    return sketch_feat
 
-def image_search(query, corpus, n_results=N_RESULTS):
-    query_embedding = compute_sketch(query)
+
+def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
+    query_embedding = compute_sketch(_query, model)
    corpus_id = 0 if corpus == "Unsplash" else 1
    image_features = torch.tensor(
        [item[0] for item in embeddings[corpus_id]]).to(device)
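For reference, the two `@st.cache_resource` functions introduced above follow Streamlit's standard pattern for one-time, process-wide setup: the decorated body runs once, and every rerun reuses the returned objects. A minimal sketch of the pattern, with hypothetical `build_model()` / `load_embeddings()` standing in for the checkpoint and pickle loading:

```python
import streamlit as st

def build_model():          # hypothetical stand-in for the checkpoint loading
    return object()

def load_embeddings():      # hypothetical stand-in for the pickle load
    return {}

@st.cache_resource          # body runs once per process; reruns reuse the result
def load_model_and_data():
    return build_model(), load_embeddings()

model, embeddings = load_model_and_data()  # cheap on every Streamlit rerun
```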
@@ -93,7 +88,6 @@ def image_search(query, corpus, n_results=N_RESULTS):
     _, max_indices = torch.topk(
         dot_product, n_results, dim=0, largest=True, sorted=True)
 
-    # Dictionary mapping paths to labels
     path_to_label = {path: idx for idx,
                      (_, path) in enumerate(embeddings[corpus_id])}
     label_to_path = {idx: path for path, idx in path_to_label.items()}
@@ -101,14 +95,14 @@ def image_search(query, corpus, n_results=N_RESULTS):
         [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device)
 
     return [
-        (
-            label_to_path[i],
-        )
+        (label_to_path[i],)
         for i in label_of_images[max_indices].cpu().numpy().tolist()
     ], dot_product[max_indices]
 
 
-def make_square(img, fill_color=(255, 255, 255)):
+@st.cache_data
+def make_square(img_path, fill_color=(255, 255, 255)):
+    img = Image.open(img_path)
     x, y = img.size
     size = max(x, y)
     new_img = Image.new("RGB", (x, y), fill_color)
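The retrieval logic above is a plain dot-product ranking: one sketch embedding is scored against every image embedding, and `torch.topk` keeps the `n_results` best. A self-contained sketch with random tensors (sizes illustrative):

```python
import torch

torch.manual_seed(0)
image_features = torch.randn(1000, 768)  # corpus embeddings (illustrative size)
query_embedding = torch.randn(1, 768)    # one sketch embedding

dot_product = (image_features @ query_embedding.T).squeeze(1)  # score per image
values, max_indices = torch.topk(dot_product, k=15, largest=True, sorted=True)
# max_indices are the corpus rows to display, best match first
```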
@@ -118,18 +112,12 @@ def make_square(img, fill_color=(255, 255, 255)):
 
 @st.cache_data
 def get_images(paths):
-    def process_image(path):
-        return make_square(Image.open(path))
-
-    processed = Pool(N_RESULTS).map(process_image, paths)
-    imgs, xs, ys = [], [], []
-    for img, x, y in processed:
-        imgs.append(img)
-        xs.append(x)
-        ys.append(y)
-    return imgs, xs, ys
+    processed = [make_square(path) for path in paths]
+    imgs, xs, ys = zip(*processed)
+    return list(imgs), list(xs), list(ys)
 
 
+@st.cache_data
 def convert_pil_to_base64(image):
     img_buffer = BytesIO()
     image.save(img_buffer, format="JPEG")
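`make_square` and `get_images` now lean on `@st.cache_data` instead of a thread pool: Streamlit memoizes the return value keyed on the arguments, so each image path is opened and padded at most once per session. A minimal sketch of that caching behavior, with a hypothetical `load_thumbnail` in place of `make_square`:

```python
import streamlit as st
from PIL import Image

@st.cache_data  # result cached per distinct `path`; repeat calls skip the work
def load_thumbnail(path: str):  # hypothetical name; same idea as make_square
    img = Image.open(path).convert("RGB")
    img.thumbnail((200, 200))   # in-place downscale, preserving aspect ratio
    return img
# First call for a path decodes the file; later calls return the cached copy.
```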
@@ -138,21 +126,6 @@ def convert_pil_to_base64(image):
     return base64_str
 
 
-def draw_reshape_encode(img, boxes, x, y):
-    boxes = [boxes.tolist()]
-    image = img.copy()
-    draw = ImageDraw.Draw(image)
-    new_x, new_y = int(x * HEIGHT / y), HEIGHT
-    for box in boxes:
-        print("box:", box)
-        draw.rectangle(
-            # (x_min, y_min, x_max, y_max)
-            [(box[0], box[1]), (box[2], box[3])],
-            outline=color,  # Box color
-            width=7  # Box width
-        )
-
-
 def get_html(url_list, encoded_images):
     html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
     for i in range(len(url_list)):
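`convert_pil_to_base64` and `get_html` work together to inline every result image as a data URI inside a single `st.markdown` call, so no per-image file serving is needed. The middle of `convert_pil_to_base64` is elided by the diff; a minimal working version of the same idea, where the encode/decode step is an assumption:

```python
import base64
from io import BytesIO
from PIL import Image

def convert_pil_to_base64(image):
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    return base64.b64encode(img_buffer.getvalue()).decode()  # assumed encoding step

img = Image.new("RGB", (64, 64), "white")
html = f"<img src='data:image/jpeg;base64,{convert_pil_to_base64(img)}'>"
# st.markdown(html, unsafe_allow_html=True) then renders the embedded image
```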
@@ -165,63 +138,40 @@ def get_html(url_list, encoded_images):
     return html
 
 
-description = """
-# Sketch-based Image Retrieval (SBIR)
-"""
-
-div_style = {
-    "display": "flex",
-    "justify-content": "center",
-    "flex-wrap": "wrap",
-}
-
-
-model, path_images, embeddings = load()
-
-
 def main():
+    if not st.session_state.initialized:
+        initialize_huggingface()
+        st.session_state.model, st.session_state.path_images, st.session_state.embeddings = load_model_and_data()
+        st.session_state.initialized = True
 
-    print("Cargando modelos...")
+    description = """
+# Self-Supervised Sketch-based Image Retrieval (S3BIR)
+
+Our approaches, S3BIR-CLIP and S3BIR-DINOv2, can produce a bimodal sketch-photo feature space from unpaired data without explicit sketch-photo pairs. Our experiments perform outstandingly in three diverse public datasets where the models are trained without real sketches.
+"""
+
+    st.sidebar.markdown(description)
     stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)
 
+    # styles
     st.markdown(
         """
         <style>
-        .block-container{
-            max-width: 1200px;
-        }
-        div.row-widget > div{
-            flex-direction: row;
-            display: flex;
-            justify-content: center;
-        }
-        div.row-widget.stRadio > div > label{
-            margin-left: 5px;
-            margin-right: 5px;
-        }
-        .row-widget {
-            margin-top: -25px;
-        }
-        section > div:first-child {
-            padding-top: 30px;
-        }
-        div.appview-container > section:first-child{
-            max-width: 320px;
-        }
-        #MainMenu {
-            visibility: hidden;
-        }
-        .stMarkdown {
-            display: grid;
-            place-items: center;
-        }
+        .block-container{ max-width: 1200px; }
+        div.row-widget > div{ flex-direction: row; display: flex; justify-content: center; color: white; }
+        div.row-widget.stRadio > div > label{ margin-left: 5px; margin-right: 5px; }
+        .row-widget { margin-top: -25px; }
+        section > div:first-child { padding-top: 30px; }
+        div.appview-container > section:first-child{ max-width: 320px; }
+        #MainMenu { visibility: hidden; }
+        .stMarkdown { display: grid; place-items: center; }
         </style>
         """,
         unsafe_allow_html=True,
     )
-    st.sidebar.markdown(description)
 
-    st.title("SBIR App")
+    st.title("S3BIR App")
     _, col, _ = st.columns((1, 1, 1))
     with col:
         canvas_result = st_canvas(
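The rewritten `main()` gates its one-time setup behind `st.session_state`, so the cached resources are fetched on the first run of a session and simply reused afterwards. A minimal sketch of the gating pattern:

```python
import streamlit as st

if "initialized" not in st.session_state:
    st.session_state.initialized = False

def main():
    if not st.session_state.initialized:
        st.session_state.model = "model"  # placeholder for load_model_and_data()
        st.session_state.initialized = True
    st.write(st.session_state.model)      # later reruns skip the setup branch

main()
```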
@@ -233,13 +183,12 @@ def main():
             key="color_annotation_app",
         )
 
-    st.columns((1, 3, 1))
     corpus = ["Ecommerce"]
+    st.columns((1, 3, 1))
 
     if canvas_result.image_data is not None:
         draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
         draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
-        draw.save("draw.jpg")
 
         draw_tensor = transforms.ToTensor()(draw)
         draw_tensor = transforms.Resize((224, 224))(draw_tensor)
@@ -248,20 +197,19 @@ def main():
         )(draw_tensor)
         draw_tensor = draw_tensor.unsqueeze(0)
 
-        retrieved, _ = image_search(draw_tensor, corpus)
+        retrieved, _ = image_search(
+            draw_tensor, corpus[0], st.session_state.model, st.session_state.embeddings)
         imgs, xs, ys = get_images([x[0] for x in retrieved])
+
         encoded_images = []
         for image_idx in range(len(imgs)):
             img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
-
             new_x, new_y = int(x * HEIGHT / y), HEIGHT
-
             encoded_images.append(convert_pil_to_base64(
                 img0.resize((new_x, new_y))))
+
         st.markdown(get_html(retrieved, encoded_images),
                     unsafe_allow_html=True)
-    else:
-        return
 
 
 if __name__ == "__main__":
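For context, the canvas-to-tensor path in `main()` (pad to an RGB square, ToTensor, Resize, Normalize, add a batch dimension) is standard torchvision preprocessing for a ViT-style encoder. The Normalize statistics are elided by the diff, so ImageNet values are assumed in this sketch:

```python
import torch
from PIL import Image, ImageOps
from torchvision import transforms

draw = Image.new("RGB", (300, 150), "white")  # stand-in for the drawn canvas
draw = ImageOps.pad(draw, size=(224, 224))    # letterbox to a square

draw_tensor = transforms.ToTensor()(draw)                 # (3, 224, 224) in [0, 1]
draw_tensor = transforms.Resize((224, 224))(draw_tensor)
draw_tensor = transforms.Normalize(                       # assumption: ImageNet stats
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(draw_tensor)
draw_tensor = draw_tensor.unsqueeze(0)                    # (1, 3, 224, 224) batch
```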
src/__pycache__/model_LN_prompt.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/model_LN_prompt.cpython-310.pyc and b/src/__pycache__/model_LN_prompt.cpython-310.pyc differ
 
src/__pycache__/options.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/options.cpython-310.pyc and b/src/__pycache__/options.cpython-310.pyc differ
 
src/model_LN_prompt.py CHANGED
@@ -1,15 +1,9 @@
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchmetrics.functional import retrieval_average_precision
 import pytorch_lightning as pl
 
 from src.dinov2.models.vision_transformer import vit_base
-
-from functools import partial
-
-# from src.clip import clip
 from src.options import opts
 
 def freeze_model(m):
@@ -31,23 +25,11 @@ class Model(pl.LightningModule):
         self.opts = opts
 
         self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)
-        print("self.dino", self.dino)
 
         # Prompt Engineering
         self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
         self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
 
-        self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
-        self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss(
-            distance_function=self.distance_fn, margin=0.2)
-
-        self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2)
-
-        self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
-
-        self.best_metric = -1e3
-        # normalization layer for the representations z1 and z2
-        # self.bn = nn.BatchNorm1d(self.opts.prompt_dim, affine=False)
 
     def configure_optimizers(self):
         if self.opts.model_type == 'one_encoder':
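After this cleanup the module keeps only what inference needs: the DINOv2 backbone plus the two learnable prompt tensors of shape (n_prompts, prompt_dim); the triplet, cosine, and KL losses were training-only. A toy sketch of how such prompt tokens are typically prepended to a token sequence, illustrative only and not the model's actual forward pass:

```python
import torch
import torch.nn as nn

n_prompts, prompt_dim = 3, 768
sk_prompt = nn.Parameter(torch.randn(n_prompts, prompt_dim))  # learnable prompts

tokens = torch.randn(4, 257, prompt_dim)  # (batch, patch tokens, dim), illustrative
prompted = torch.cat(
    [sk_prompt.unsqueeze(0).expand(tokens.size(0), -1, -1), tokens], dim=1)
print(prompted.shape)  # torch.Size([4, 260, 768]): prompts prepended per sample
```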
src/options.py CHANGED
@@ -1,18 +1,17 @@
 import argparse
 
-parser = argparse.ArgumentParser(description='Sketch-based OD')
+parser = argparse.ArgumentParser(description='S3BIR')
 
-parser.add_argument('--exp_name', type=str, default='LN_prompt')
+parser.add_argument('--exp_name', type=str, default='DINOv2_prompt')
 
 # ----------------------
 # Training Params
 # ----------------------
 
-parser.add_argument('--clip_lr', type=float, default=1e-4)
-parser.add_argument('--clip_LN_lr', type=float, default=1e-6)
+parser.add_argument('--dinov2_lr', type=float, default=1e-4)
+parser.add_argument('--dinov2_LN_lr', type=float, default=1e-6)
 parser.add_argument('--prompt_lr', type=float, default=1e-4)
 parser.add_argument('--linear_lr', type=float, default=1e-4)
-parser.add_argument('--model_type', type=str, default='one_encoder', choices=['one_encoder', 'two_encoder'])
 
 # ----------------------
 # ViT Prompt Parameters
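Since src/options.py is imported at module load time by both the model and the Streamlit app, the parser presumably runs on import; `parse_known_args()` is the import-safe variant because unknown flags (for example Streamlit's own CLI arguments) are ignored rather than fatal. The actual parse call is not shown in this diff, so the following is only a self-contained sketch of the renamed options:

```python
import argparse

parser = argparse.ArgumentParser(description='S3BIR')
parser.add_argument('--exp_name', type=str, default='DINOv2_prompt')
parser.add_argument('--dinov2_lr', type=float, default=1e-4)

# parse_known_args() tolerates flags meant for the host process (assumption:
# this is how a module-level parser stays importable under Streamlit).
opts, _unknown = parser.parse_known_args()
print(opts.exp_name, opts.dinov2_lr)
```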