CHSTR committed on
Commit
dd12453
1 Parent(s): 265ae36

Required files

Files changed (2)
  1. app.py +348 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,348 @@
+ from html import escape
+ import requests
+ from io import BytesIO
+ import base64
+ from multiprocessing.dummy import Pool
+ from PIL import Image, ImageDraw
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import torch
+ # from transformers import CLIPProcessor, CLIPModel
+ # from transformers import OwlViTProcessor, OwlViTForObjectDetection
+ # from transformers.image_utils import ImageFeatureExtractionMixin
+ import tokenizers
+
+ import pickle as pkl
+
+ # sketches
+ from streamlit_drawable_canvas import st_canvas
+ from PIL import Image, ImageOps
+ from torchvision import transforms
+
+
+ # model
+ import os
+ # It does not recognize the src folder located two levels down
+ from src.model_LN_prompt import Model
+ from src.options import opts
+
+
+ DEBUG = False
+ if DEBUG:
+     MODEL = "vit-base-patch32"
+ else:
+     MODEL = "vit-large-patch14-336"
+ CLIP_MODEL = f"openai/clip-{MODEL}"
+ OWL_MODEL = f"google/owlvit-base-patch32"
+
+ if not DEBUG and torch.cuda.is_available():
+     device = torch.device("cuda")
+ else:
+     device = torch.device("cpu")
+
+ HEIGHT = 350
+ N_RESULTS = 5
+
+ from huggingface_hub import hf_hub_download, login
+
+ token = os.getenv("HUGGINGFACE_TOKEN")
+
+ # Authenticate using the token
+ login(token=token)
+
+
+ color = st.get_option("theme.primaryColor")
+ if color is None:
+     color = (0, 255, 0)
+ else:
+     color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))
+
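+ # load() is cached by Streamlit: it downloads the trained sketch-retrieval
+ # checkpoint and the precomputed region embeddings from the CHSTR/DocExplore
+ # hub repo, and returns the model, the page-image directory, and a dict of
+ # embeddings keyed by corpus id (0 = SAM regions, 1 = GroundingDINO regions).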
+ @st.cache_resource
+ def load():
+     path_images = 'data/doc_explore/DocExplore_images/'
+     path_model = hf_hub_download(repo_id="CHSTR/DocExplore", filename="epoch=16-mAP=0.66_triplet.ckpt")  # "models/epoch=16-mAP=0.66_triplet.ckpt"
+
+     try:
+         model = Model().to(device)
+         model_checkpoint = torch.load(path_model, map_location=device)  # 'model_60k_images_073.ckpt' -> model trained on 60k images without pidinet
+         model.load_state_dict(model_checkpoint['state_dict'])  # 'modified_model_083.ckpt' -> model trained on 60k images with pidinet
+         model.eval()  # 'original_model_083.ckpt' -> original model trained on 60k images with pidinet
+         print("Model loaded successfully")
+     except Exception:
+         print("Could not load the model. Try again with a different --model_type argument")
+         exit()
+
+     embeddings_file_1 = hf_hub_download(repo_id="CHSTR/DocExplore", filename="dino_flicker_docexplore_groundingDINO.pkl")
+     embeddings_file_0 = hf_hub_download(repo_id="CHSTR/DocExplore", filename="docexp_embeddings.pkl")
+
+     embeddings = {
+         0: pkl.load(open(embeddings_file_0, "rb")),
+         1: pkl.load(open(embeddings_file_1, "rb"))
+     }
+
+     # embeddings = {
+     #     0: pkl.load(open("docexp_embeddings.pkl", "rb")),
+     #     1: pkl.load(open("dino_flicker_docexplore_groundingDINO.pkl", "rb"))
+     # }
+
+     # Update the image paths stored in the embeddings
+     # for i in range(len(embeddings[0])):
+     #     print(embeddings[0][i])
+     #     embeddings[0][i] = (embeddings[0][i][0], path_images + "/".join(embeddings[0][i][1].split("/")[:-3]))
+
+     # for i in range(len(embeddings[1])):
+     #     print(embeddings[1][i])
+     #     embeddings[1][i] = (embeddings[1][i][0], path_images + "/".join(embeddings[1][i][1].split("/")[:-3]))
+
+     return model, path_images, embeddings
+
+ print("Loading models...")
+ model, path_images, embeddings = load()
+ source = {0: "\nDocExplore SAM", 1: "\nDocExplore GroundingDINO"}
+
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)
+
+ dataset_transforms = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+
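+ # Despite its name (kept from the original CLIP text-search demo, see the
+ # commented-out code below), compute_text_embeddings embeds the rasterized
+ # sketch with the loaded model and returns its feature vector.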
+ def compute_text_embeddings(sketch):
+     with torch.no_grad():
+         sketch_feat = model(sketch.to(device), dtype='sketch')
+     return sketch_feat
+     # inputs = clip_processor(text=list_of_strings, return_tensors="pt", padding=True).to(
+     #     device
+     # )
+     # with torch.no_grad():
+     #     result = clip_model.get_text_features(**inputs).detach().cpu().numpy()
+     # return result / np.linalg.norm(result, axis=1, keepdims=True)
+     # return torch.randn(1, 768)
+
+
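+ # image_search embeds the sketch, scores it against every precomputed region
+ # embedding of the selected corpus with a dot product, and returns the page
+ # image paths, bounding boxes and scores of the top N_RESULTS matches.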
+ def image_search(query, corpus, n_results=N_RESULTS):
+     query_embedding = compute_text_embeddings(query)
+     corpus_id = 0 if corpus == "DocExplore SAM" else 1
+     image_features = torch.tensor([item[0] for item in embeddings[corpus_id]]).to(device)
+     bbox_of_images = torch.tensor([item[1] for item in embeddings[corpus_id]]).to(device)
+     label_of_images = torch.tensor([item[2] for item in embeddings[corpus_id]]).to(device)
+     dot_product = (image_features @ query_embedding.T)[:, 0]
+     _, max_indices = torch.topk(dot_product, n_results, dim=0, largest=True, sorted=True)
+
+     return [
+         (
+             path_images + "page" + str(i) + ".jpg",
+         )
+         for i in label_of_images[max_indices].cpu().numpy().tolist()
+     ], bbox_of_images[max_indices], dot_product[max_indices]
+
+
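+ # Note: despite its name, make_square keeps the original (x, y) canvas size,
+ # so it effectively returns a copy of the page image together with its
+ # dimensions; the computed `size` is currently unused.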
+ def make_square(img, fill_color=(255, 255, 255)):
+     x, y = img.size
+     size = max(x, y)
+     new_img = Image.new("RGB", (x, y), fill_color)
+     new_img.paste(img)
+     return new_img, x, y
+
+ @st.cache_data
+ def get_images(paths):
+     def process_image(path):
+         return make_square(Image.open(path))
+
+     processed = Pool(N_RESULTS).map(process_image, paths)
+     imgs, xs, ys = [], [], []
+     for img, x, y in processed:
+         imgs.append(img)
+         xs.append(x)
+         ys.append(y)
+     return imgs, xs, ys
+
+
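+ # keep_best_boxes does a greedy, NMS-style filtering: boxes below the score
+ # threshold are dropped, then overlapping pairs are compared by IoU and
+ # containment so only the better-scoring box of each pair survives. It is not
+ # called anywhere in this app (the score-threshold slider in main() is
+ # commented out).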
+ def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
+     candidates = []
+     for box, score in zip(boxes, scores):
+         box = [round(i, 0) for i in box.tolist()]
+         if score >= score_threshold:
+             candidates.append((box, float(score)))
+
+     to_ignore = set()
+     for i in range(len(candidates) - 1):
+         if i in to_ignore:
+             continue
+         for j in range(i + 1, len(candidates)):
+             if j in to_ignore:
+                 continue
+             xmin1, ymin1, xmax1, ymax1 = candidates[i][0]
+             xmin2, ymin2, xmax2, ymax2 = candidates[j][0]
+             if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
+                 continue
+             else:
+                 xmin_inter, xmax_inter = sorted(
+                     [xmin1, xmax1, xmin2, xmax2])[1:3]
+                 ymin_inter, ymax_inter = sorted(
+                     [ymin1, ymax1, ymin2, ymax2])[1:3]
+                 area_inter = (xmax_inter - xmin_inter) * \
+                     (ymax_inter - ymin_inter)
+                 area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
+                 area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
+                 iou = area_inter / (area1 + area2 - area_inter)
+                 if iou > max_iou:
+                     if candidates[i][1] > candidates[j][1]:
+                         to_ignore.add(j)
+                     else:
+                         to_ignore.add(i)
+                         break
+                 else:
+                     if area_inter / area1 > 0.9:
+                         if candidates[i][1] < 1.1 * candidates[j][1]:
+                             to_ignore.add(i)
+                     if area_inter / area2 > 0.9:
+                         if 1.1 * candidates[i][1] > candidates[j][1]:
+                             to_ignore.add(j)
+     return [candidates[i][0] for i in range(len(candidates)) if i not in to_ignore]
+
+
+ def convert_pil_to_base64(image):
+     img_buffer = BytesIO()
+     image.save(img_buffer, format="JPEG")
+     byte_data = img_buffer.getvalue()
+     base64_str = base64.b64encode(byte_data)
+     return base64_str
+
+
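+ # draw_reshape_encode draws the retrieved region's bounding box on a copy of
+ # the page image, resizes it to the fixed display HEIGHT while keeping the
+ # aspect ratio, and returns the result base64-encoded for inline HTML.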
+ def draw_reshape_encode(img, boxes, x, y):
+     boxes = [boxes.tolist()]
+     image = img.copy()
+     draw = ImageDraw.Draw(image)
+     new_x, new_y = int(x * HEIGHT / y), HEIGHT
+     for box in boxes:
+         print("box:", box)
+         draw.rectangle(
+             [(box[0], box[1]), (box[2], box[3])],  # (x_min, y_min, x_max, y_max)
+             outline=color,  # Box color
+             width=10  # Box width
+         )
+     # if x > y:
+     #     image = image.crop((0, (x - y) / 2, x, x - (x - y) / 2))
+     # else:
+     #     image = image.crop(((y - x) / 2, 0, y - (y - x) / 2, y))
+     return convert_pil_to_base64(image.resize((new_x, new_y)))
+
+
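+ # get_html lays the encoded result images out as an HTML flexbox of <img>
+ # tags with the page path as the tooltip; main() renders it via
+ # st.markdown(..., unsafe_allow_html=True).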
+ def get_html(url_list, encoded_images):
+     html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
+     for i in range(len(url_list)):
+         title, encoded = url_list[i][0], encoded_images[i]
+         html = (
+             html
+             + f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 1px' src='data:image/jpeg;base64,{encoded.decode()}'>"
+         )
+     html += "</div>"
+     return html
+
+
+ description = """
+ # Sketch-based Detection
+ This app retrieves images from the [DocExplore](https://www.docexplore.eu/?lang=en) dataset based on a sketch query.
+ **Tip 1**: you can draw a sketch in the canvas.
+ **Tip 2**: you can change the size of the stroke with the slider.
+ The model used in this application is DINOv2, trained in a self-supervised manner on the Flickr25k dataset.
+ """
+
+ div_style = {
+     "display": "flex",
+     "justify-content": "center",
+     "flex-wrap": "wrap",
+ }
+
+
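+ # main() injects the page CSS, then builds a two-column layout: the left
+ # column holds the drawable sketch canvas and the corpus selector, the right
+ # column shows the retrieved DocExplore pages with the matching region drawn.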
+ def main():
+     st.markdown(
+         """
+         <style>
+         .block-container{
+             max-width: 1600px;
+         }
+         div.row-widget > div{
+             flex-direction: row;
+             display: flex;
+             justify-content: center;
+         }
+         div.row-widget.stRadio > div > label{
+             margin-left: 5px;
+             margin-right: 5px;
+         }
+         .row-widget {
+             margin-top: -25px;
+         }
+         section > div:first-child {
+             padding-top: 30px;
+         }
+         div.appview-container > section:first-child{
+             max-width: 320px;
+         }
+         #MainMenu {
+             visibility: hidden;
+         }
+         .stMarkdown {
+             display: grid;
+             place-items: center;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+     st.sidebar.markdown(description)
+
+     st.title("One-Shot Detection")
+
+     # Create two main columns
+     left_col, right_col = st.columns([0.2, 0.8])  # Adjust the weights as needed
+
+     with left_col:
+         # Canvas for drawing
+         canvas_result = st_canvas(
+             background_color="#eee",
+             stroke_width=stroke_width,
+             update_streamlit=True,
+             height=300,
+             width=300,
+             key="color_annotation_app",
+         )
+
+         # Input controls
+         query = [0]
+         corpus = st.radio("", ["DocExplore SAM", "DocExplore GroundingDINO"], index=0)
+         # score_threshold = st.slider(
+         #     "Score threshold", min_value=0.01, max_value=1.0, value=0.5, step=0.01
+         # )
+
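+     # Right column: the canvas RGBA drawing is converted to RGB, padded to
+     # 224x224, normalized with the same ImageNet statistics as
+     # dataset_transforms, and batched before being passed to image_search.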
+     with right_col:
+         if canvas_result.image_data is not None:
+             draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
+             draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
+             draw.save("draw.jpg")
+
+             draw_tensor = transforms.ToTensor()(draw)
+             draw_tensor = transforms.Resize((224, 224))(draw_tensor)
+             draw_tensor = transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+             )(draw_tensor)
+             draw_tensor = draw_tensor.unsqueeze(0)
+         else:
+             return
+
+         if len(query) > 0:
+             retrieved, bbox_of_images, dot_product = image_search(draw_tensor, corpus)
+             imgs, xs, ys = get_images([x[0] for x in retrieved])
+             encoded_images = []
+             for image_idx in range(len(imgs)):
+                 img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
+                 encoded_images.append(draw_reshape_encode(img0, bbox_of_images[image_idx], x, y))
+             st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ numpy==1.21.5
+ omegaconf==2.3.0
+ Pillow==11.0.0
+ pytorch_lightning==2.4.0
+ scipy==1.8.0
+ streamlit
+ streamlit_drawable_canvas
+ torchmetrics
+ torchmetrics
+ torchvision