Use the dataset from Hugging Face
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +68 -120
- src/__pycache__/model_LN_prompt.cpython-310.pyc +0 -0
- src/__pycache__/options.cpython-310.pyc +0 -0
- src/model_LN_prompt.py +0 -18
- src/options.py +4 -5
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.75 kB).
app.py
CHANGED
@@ -1,90 +1,85 @@
 import os
-
 import streamlit as st
 from io import BytesIO
-import base64
 from multiprocessing.dummy import Pool
-
-
+import base64
+from PIL import Image, ImageOps
 import torch
 from torchvision import transforms
-
-# sketches
 from streamlit_drawable_canvas import st_canvas
 from src.model_LN_prompt import Model
-
-
-import pickle as pkl
 from html import escape
+import pickle as pkl
 from huggingface_hub import hf_hub_download, login
 from datasets import load_dataset
 
-token = os.getenv("HUGGINGFACE_TOKEN")
 
-
-
+if 'initialized' not in st.session_state:
+    st.session_state.initialized = False
 
-# Variables
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print(f"Device: {device}")
 HEIGHT = 200
-N_RESULTS =
+N_RESULTS = 20
 color = st.get_option("theme.primaryColor")
 if color is None:
     color = (0, 0, 255)
 else:
     color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))
 
+
 @st.cache_resource
-def
-
+def initialize_huggingface():
+    token = os.getenv("HUGGINGFACE_TOKEN")
+    if token:
+        login(token=token)
+    else:
+        st.error("HUGGINGFACE_TOKEN not found in environment variables")
+
+
+@st.cache_resource
+def load_model_and_data():
+    print("Loading everything...")
     dataset = load_dataset("CHSTR/ecommerce")
     path_images = "/".join(dataset['validation']
                            ['image'][0].filename.split("/")[:-3]) + "/"
-    print(f"Directorio de imágenes: {path_images}")
 
-    #
+    # Download model
     path_model = hf_hub_download(
         repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
-    print(f"Archivo del modelo descargado en: {path_model}")
 
-    #
-    model = Model()
+    # Load model
+    model = Model().to(device)
     model_checkpoint = torch.load(path_model, map_location=device)
     model.load_state_dict(model_checkpoint['state_dict'])
     model.eval()
-    # model.to(device)
-    print("Modelo cargado exitosamente")
 
-    #
+    # Download and load embeddings
     embeddings_file = hf_hub_download(
         repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
-    print(f"Archivo de embeddings descargado en: {embeddings_file}")
 
     embeddings = {
         0: pkl.load(open(embeddings_file, "rb")),
         1: pkl.load(open(embeddings_file, "rb"))
     }
 
-    #
-    for
-        embeddings[
-
-
-
-    for i in range(len(embeddings[1])):
-        embeddings[1][i] = (embeddings[1][i][0], path_images +
-                            "/".join(embeddings[1][i][1].split("/")[-3:]))
+    # Update image paths
+    for corpus_id in [0, 1]:
+        embeddings[corpus_id] = [
+            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
+            for emb in embeddings[corpus_id]
+        ]
 
     return model, path_images, embeddings
 
-
+
+def compute_sketch(_sketch, model):
     with torch.no_grad():
-        sketch_feat = model(
+        sketch_feat = model(_sketch.to(device), dtype='sketch')
     return sketch_feat
 
-
-
+
+def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
+    query_embedding = compute_sketch(_query, model)
     corpus_id = 0 if corpus == "Unsplash" else 1
     image_features = torch.tensor(
         [item[0] for item in embeddings[corpus_id]]).to(device)
@@ -93,7 +88,6 @@ def image_search(query, corpus, n_results=N_RESULTS):
     _, max_indices = torch.topk(
         dot_product, n_results, dim=0, largest=True, sorted=True)
 
-    # Diccionario para mapear los paths a labels
     path_to_label = {path: idx for idx,
                      (_, path) in enumerate(embeddings[corpus_id])}
     label_to_path = {idx: path for path, idx in path_to_label.items()}
@@ -101,14 +95,14 @@ def image_search(query, corpus, n_results=N_RESULTS):
         [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device)
 
     return [
-        (
-            label_to_path[i],
-        )
+        (label_to_path[i],)
         for i in label_of_images[max_indices].cpu().numpy().tolist()
-    ], dot_product[max_indices]
+    ], dot_product[max_indices]
 
 
-def make_square(img, fill_color=(255, 255, 255)):
+@st.cache_data
+def make_square(img_path, fill_color=(255, 255, 255)):
+    img = Image.open(img_path)
     x, y = img.size
     size = max(x, y)
     new_img = Image.new("RGB", (x, y), fill_color)
@@ -118,18 +112,12 @@ def make_square(img, fill_color=(255, 255, 255)):
 
 @st.cache_data
 def get_images(paths):
-
-
-
-    processed = Pool(N_RESULTS).map(process_image, paths)
-    imgs, xs, ys = [], [], []
-    for img, x, y in processed:
-        imgs.append(img)
-        xs.append(x)
-        ys.append(y)
-    return imgs, xs, ys
+    processed = [make_square(path) for path in paths]
+    imgs, xs, ys = zip(*processed)
+    return list(imgs), list(xs), list(ys)
 
 
+@st.cache_data
 def convert_pil_to_base64(image):
     img_buffer = BytesIO()
     image.save(img_buffer, format="JPEG")
@@ -138,21 +126,6 @@ def convert_pil_to_base64(image):
     return base64_str
 
 
-def draw_reshape_encode(img, boxes, x, y):
-    boxes = [boxes.tolist()]
-    image = img.copy()
-    draw = ImageDraw.Draw(image)
-    new_x, new_y = int(x * HEIGHT / y), HEIGHT
-    for box in boxes:
-        print("box:", box)
-        draw.rectangle(
-            # (x_min, y_min, x_max, y_max)
-            [(box[0], box[1]), (box[2], box[3])],
-            outline=color,  # Box color
-            width=7  # Box width
-        )
-
-
 def get_html(url_list, encoded_images):
     html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
     for i in range(len(url_list)):
@@ -165,63 +138,40 @@ def get_html(url_list, encoded_images):
     return html
 
 
-
-
-
-
-
-    "display": "flex",
-    "justify-content": "center",
-    "flex-wrap": "wrap",
-}
-
+def main():
+    if not st.session_state.initialized:
+        initialize_huggingface()
+        st.session_state.model, st.session_state.path_images, st.session_state.embeddings = load_model_and_data()
+        st.session_state.initialized = True
 
-
+    description = """
+# Self-Supervised Sketch-based Image Retrieval (S3BIR)
 
+Our approaches, S3BIR-CLIP and S3BIR-DINOv2, can produce a bimodal sketch-photo feature space from unpaired data without explicit sketch-photo pairs. Our experiments perform outstandingly in three diverse public datasets where the models are trained without real sketches.
 
-
+"""
 
-
+    st.sidebar.markdown(description)
     stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)
 
+    # styles
     st.markdown(
         """
         <style>
-        .block-container{
-
-        }
-
-
-
-
-        }
-        div.row-widget.stRadio > div > label{
-            margin-left: 5px;
-            margin-right: 5px;
-        }
-        .row-widget {
-            margin-top: -25px;
-        }
-        section > div:first-child {
-            padding-top: 30px;
-        }
-        div.appview-container > section:first-child{
-            max-width: 320px;
-        }
-        #MainMenu {
-            visibility: hidden;
-        }
-        .stMarkdown {
-            display: grid;
-            place-items: center;
-        }
+        .block-container{ max-width: 1200px; }
+        div.row-widget > div{ flex-direction: row; display: flex; justify-content: center; color: white; }
+        div.row-widget.stRadio > div > label{ margin-left: 5px; margin-right: 5px; }
+        .row-widget { margin-top: -25px; }
+        section > div:first-child { padding-top: 30px; }
+        div.appview-container > section:first-child{ max-width: 320px; }
+        #MainMenu { visibility: hidden; }
+        .stMarkdown { display: grid; place-items: center; }
         </style>
         """,
         unsafe_allow_html=True,
     )
-    st.sidebar.markdown(description)
 
-    st.title("
+    st.title("S3BIR App")
     _, col, _ = st.columns((1, 1, 1))
     with col:
         canvas_result = st_canvas(
@@ -233,13 +183,12 @@ def main():
            key="color_annotation_app",
        )
 
-    st.columns((1, 3, 1))
     corpus = ["Ecommerce"]
+    st.columns((1, 3, 1))
 
     if canvas_result.image_data is not None:
        draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
        draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
-        draw.save("draw.jpg")
 
        draw_tensor = transforms.ToTensor()(draw)
        draw_tensor = transforms.Resize((224, 224))(draw_tensor)
@@ -248,20 +197,19 @@ def main():
        )(draw_tensor)
        draw_tensor = draw_tensor.unsqueeze(0)
 
-        retrieved, _ = image_search(
+        retrieved, _ = image_search(
+            draw_tensor, corpus[0], st.session_state.model, st.session_state.embeddings)
        imgs, xs, ys = get_images([x[0] for x in retrieved])
+
        encoded_images = []
        for image_idx in range(len(imgs)):
            img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
-
            new_x, new_y = int(x * HEIGHT / y), HEIGHT
-
            encoded_images.append(convert_pil_to_base64(
                img0.resize((new_x, new_y))))
+
        st.markdown(get_html(retrieved, encoded_images),
                    unsafe_allow_html=True)
-    else:
-        return
 
 
 if __name__ == "__main__":
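Note on the retrieval step in app.py: the `dot_product` tensor that `image_search` ranks is computed between the hunks shown above, so the scoring is easy to miss when reading the diff. Below is a minimal, self-contained sketch of that ranking step under assumed inputs; the 768-dimensional embeddings and the cosine-style normalization are illustrative assumptions, not code from the repository.

import torch
import torch.nn.functional as F

# Illustrative stand-ins: one sketch embedding vs. a gallery of precomputed image embeddings.
sketch_feat = torch.randn(1, 768)          # would come from compute_sketch(...)
image_features = torch.randn(1000, 768)    # would come from stacking item[0] of embeddings[corpus_id]

# Cosine-style scores: normalize both sides, then one matrix-vector product scores every image.
dot_product = F.normalize(image_features, dim=-1) @ F.normalize(sketch_feat, dim=-1).squeeze(0)

# Same top-k call app.py uses to keep the N_RESULTS best matches, ordered by score.
_, max_indices = torch.topk(dot_product, k=20, dim=0, largest=True, sorted=True)
print(max_indices[:5], dot_product[max_indices][:5])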
src/__pycache__/model_LN_prompt.cpython-310.pyc
CHANGED
Binary files a/src/__pycache__/model_LN_prompt.cpython-310.pyc and b/src/__pycache__/model_LN_prompt.cpython-310.pyc differ
src/__pycache__/options.cpython-310.pyc
CHANGED
Binary files a/src/__pycache__/options.cpython-310.pyc and b/src/__pycache__/options.cpython-310.pyc differ
src/model_LN_prompt.py
CHANGED
@@ -1,15 +1,9 @@
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchmetrics.functional import retrieval_average_precision
 import pytorch_lightning as pl
 
 from src.dinov2.models.vision_transformer import vit_base
-
-from functools import partial
-
-# from src.clip import clip
 from src.options import opts
 
 def freeze_model(m):
@@ -31,23 +25,11 @@ class Model(pl.LightningModule):
         self.opts = opts
 
         self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)
-        print("self.dino", self.dino)
 
         # Prompt Engineering
         self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
         self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
 
-        self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
-        self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss(
-            distance_function=self.distance_fn, margin=0.2)
-
-        self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2)
-
-        self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
-
-        self.best_metric = -1e3
-        # normalization layer for the representations z1 and z2
-        # self.bn = nn.BatchNorm1d(self.opts.prompt_dim, affine=False)
 
     def configure_optimizers(self):
         if self.opts.model_type == 'one_encoder':
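With the training-only losses stripped out, the remaining `Model` is used purely for inference, mirroring app.py above. A minimal usage sketch, assuming the repository layout from this commit; the local checkpoint path is illustrative (app.py obtains it via hf_hub_download).

import torch
from src.model_LN_prompt import Model  # module changed in this diff

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Mirrors app.py: build the model, load the fine-tuned checkpoint, switch to eval mode.
model = Model().to(device)
ckpt = torch.load("dinov2_ecommerce.ckpt", map_location=device)  # illustrative local path
model.load_state_dict(ckpt["state_dict"])
model.eval()

# Encode one 224x224 sketch; dtype='sketch' matches the call made in app.py.
sketch = torch.rand(1, 3, 224, 224, device=device)
with torch.no_grad():
    sketch_feat = model(sketch, dtype='sketch')
print(sketch_feat.shape)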
src/options.py
CHANGED
@@ -1,18 +1,17 @@
 import argparse
 
-parser = argparse.ArgumentParser(description='
+parser = argparse.ArgumentParser(description='S3BIR')
 
-parser.add_argument('--exp_name', type=str, default='
+parser.add_argument('--exp_name', type=str, default='DINOv2_prompt')
 
 # ----------------------
 # Training Params
 # ----------------------
 
-parser.add_argument('--
-parser.add_argument('--
+parser.add_argument('--dinov2_lr', type=float, default=1e-4)
+parser.add_argument('--dinov2_LN_lr', type=float, default=1e-6)
 parser.add_argument('--prompt_lr', type=float, default=1e-4)
 parser.add_argument('--linear_lr', type=float, default=1e-4)
-parser.add_argument('--model_type', type=str, default='one_encoder', choices=['one_encoder', 'two_encoder'])
 
 # ----------------------
 # ViT Prompt Parameters
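For reference, a standalone sketch of how the updated flags resolve to their defaults. The real module is imported as `from src.options import opts`; the use of parse_known_args here (so unrelated argv entries, for example Streamlit's own, are ignored) is an assumption about how opts is built, not a quote from src/options.py.

import argparse

# Only the arguments touched by this commit are reproduced here.
parser = argparse.ArgumentParser(description='S3BIR')
parser.add_argument('--exp_name', type=str, default='DINOv2_prompt')
parser.add_argument('--dinov2_lr', type=float, default=1e-4)
parser.add_argument('--dinov2_LN_lr', type=float, default=1e-6)
parser.add_argument('--prompt_lr', type=float, default=1e-4)
parser.add_argument('--linear_lr', type=float, default=1e-4)

# Parsing an empty argv yields the defaults the Space runs with.
opts, _ = parser.parse_known_args([])
print(opts.exp_name, opts.dinov2_lr, opts.dinov2_LN_lr)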