Spaces:
Sleeping
Sleeping
Hafiz Imam
commited on
Commit
•
0fc3879
1
Parent(s):
014ee4b
clean-up
Browse files
main.py
CHANGED
@@ -20,15 +20,15 @@ from utils import preprocess
|
|
20 |
img_paths = get_img_paths('image_data','')
|
21 |
folders = images_by_label(img_paths)
|
22 |
img_conf = get_embeddings(preprocess,model,device,img_paths,'')
|
23 |
-
|
24 |
-
|
|
|
|
|
25 |
fig = go.Figure()
|
26 |
-
fig = show_scatter(
|
27 |
|
28 |
fig.update_traces(hoverinfo="none", hovertemplate=None)
|
29 |
|
30 |
-
prompt_x = 'CEO'
|
31 |
-
prompt_y = 'doctor'
|
32 |
def serve_layout():
|
33 |
session_id = str(uuid.uuid4())
|
34 |
return update_div(prompt_x,prompt_y,fig,df,img_conf,img_paths,session_id)
|
|
|
20 |
img_paths = get_img_paths('image_data','')
|
21 |
folders = images_by_label(img_paths)
|
22 |
img_conf = get_embeddings(preprocess,model,device,img_paths,'')
|
23 |
+
|
24 |
+
prompt_x = 'CEO'
|
25 |
+
prompt_y = 'doctor'
|
26 |
+
generic_cat,df = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, ['indian businessman','indian businesswoman'])
|
27 |
fig = go.Figure()
|
28 |
+
fig = show_scatter(prompt_x, prompt_y,generic_cat)
|
29 |
|
30 |
fig.update_traces(hoverinfo="none", hovertemplate=None)
|
31 |
|
|
|
|
|
32 |
def serve_layout():
|
33 |
session_id = str(uuid.uuid4())
|
34 |
return update_div(prompt_x,prompt_y,fig,df,img_conf,img_paths,session_id)
|
utils.py
CHANGED
@@ -10,12 +10,10 @@ from tqdm import tqdm
|
|
10 |
import plotly.graph_objects as go
|
11 |
import pandas as pd
|
12 |
import base64
|
13 |
-
from dash_extensions.enrich import
|
14 |
-
ServersideOutputTransform
|
15 |
import dash
|
16 |
import shutil
|
17 |
import json
|
18 |
-
import uuid
|
19 |
|
20 |
def download_images(query, number_of_images, output_dir, num_threads, timeout=120):
|
21 |
"""
|
@@ -95,17 +93,6 @@ def images_by_label(img_paths):
|
|
95 |
folders[folder_name].append(img_path)
|
96 |
return folders
|
97 |
|
98 |
-
def load_image(img: str):
|
99 |
-
"""
|
100 |
-
Load an image and return it.
|
101 |
-
"""
|
102 |
-
# if isinstance(img, str):
|
103 |
-
# img = Image.open(img).convert('RGB')
|
104 |
-
# else:
|
105 |
-
# img = Image.open(io.BytesIO(img)).convert('RGB')
|
106 |
-
transformed_img = transform_image(img)[:3].unsqueeze(0)
|
107 |
-
return transformed_img
|
108 |
-
|
109 |
def compute_embedding(preprocess,model,device,img):
|
110 |
"""
|
111 |
Compute the embedding for an image and return it
|
@@ -202,31 +189,9 @@ def update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, quer
|
|
202 |
df = pd.DataFrame.from_dict(new_dict)
|
203 |
|
204 |
return generic_cat, df
|
205 |
-
|
206 |
-
def show_histo(generic_cat):
|
207 |
-
"""
|
208 |
-
Show a histogram of similarities for male, female, and generic categories.
|
209 |
-
Args:
|
210 |
-
generic_cat (dict): A dictionary containing similarities for generic categories.
|
211 |
-
Returns:
|
212 |
-
None
|
213 |
-
"""
|
214 |
-
# Create a new figure
|
215 |
-
fig = go.Figure()
|
216 |
-
|
217 |
-
# Add histograms for generic categories
|
218 |
-
for cat, values in generic_cat.items():
|
219 |
-
fig.add_trace(go.Histogram(x=[v[0] for v in values], name=f'{cat}', histnorm='probability'))
|
220 |
-
|
221 |
-
# Overlay histograms
|
222 |
-
fig.update_layout(barmode='overlay')
|
223 |
-
# Reduce opacity to see both histograms
|
224 |
-
fig.update_traces(opacity=0.55)
|
225 |
-
|
226 |
-
# Show the figure
|
227 |
-
fig.show()
|
228 |
|
229 |
-
|
|
|
230 |
"""
|
231 |
Show a scatter plot of similarities for male, female, and generic categories.
|
232 |
Args:
|
@@ -240,10 +205,8 @@ def show_scatter(prompt, generic_cat):
|
|
240 |
fig = go.Figure()
|
241 |
for cat, values in generic_cat.items():
|
242 |
fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
|
243 |
-
|
244 |
-
|
245 |
-
fig.update_layout(title="Similarity Scatter Plot", xaxis_title="CEO Similarity Score", yaxis_title=prompt+' similarity score')
|
246 |
-
|
247 |
return fig
|
248 |
|
249 |
def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
|
@@ -257,7 +220,6 @@ def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
|
|
257 |
dash_html_components.Div: A Div containing inputs and the graph.
|
258 |
"""
|
259 |
num_options = [1, 5, 10, 20, 50]
|
260 |
-
print(session_id)
|
261 |
return html.Div(
|
262 |
[
|
263 |
dcc.Store(id='memory-output', data={'df': df.to_json(),'img_conf':pd.DataFrame.from_dict(img_conf).to_json(),'img_paths':[i for i in img_paths]}),
|
@@ -306,39 +268,6 @@ def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
|
|
306 |
style={'padding': '20px', 'border': '1px solid #e0e0e0', 'border-radius': '10px'}
|
307 |
)
|
308 |
|
309 |
-
|
310 |
-
|
311 |
-
def display_images_with_preprocessing(img_paths, preprocess, device):
|
312 |
-
"""
|
313 |
-
Display original and preprocessed images side by side.
|
314 |
-
Args:
|
315 |
-
img_paths (list): List of paths to image files.
|
316 |
-
preprocess (callable): Function to preprocess the images.
|
317 |
-
device: Device on which the preprocessing will be performed.
|
318 |
-
Returns:
|
319 |
-
None
|
320 |
-
"""
|
321 |
-
# Process only the first 5 images
|
322 |
-
for img_path in img_paths:
|
323 |
-
# Preprocess the image
|
324 |
-
x = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
|
325 |
-
# Convert the tensor back to a numpy array
|
326 |
-
preprocessed_image_np = x.squeeze(0).cpu().numpy().transpose((1, 2, 0))
|
327 |
-
|
328 |
-
# Plot the original and preprocessed images side by side
|
329 |
-
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
|
330 |
-
axes[0].imshow(Image.open(img_path))
|
331 |
-
axes[0].set_title('Original Image')
|
332 |
-
axes[0].axis('off')
|
333 |
-
|
334 |
-
axes[1].imshow(preprocessed_image_np)
|
335 |
-
axes[1].set_title('Preprocessed Image')
|
336 |
-
axes[1].axis('off')
|
337 |
-
|
338 |
-
plt.show()
|
339 |
-
def download_images_parallel(args):
|
340 |
-
query, n_images_per_class, data_path, num_threads = args
|
341 |
-
download_images(query, n_images_per_class, data_path,num_threads)
|
342 |
def clean_up(path,session_id):
|
343 |
data_path = os.path.join(path,session_id)
|
344 |
if os.path.exists(data_path):
|
@@ -432,13 +361,7 @@ def register_callbacks(app, df, clip, device, model, folders, img_conf, queries)
|
|
432 |
img_conf = {k:[v['0'],v['1']] for k,v in img_conf.items()}
|
433 |
folders = images_by_label(img_paths)
|
434 |
generic_cat, df = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
|
435 |
-
|
436 |
-
fig = go.Figure()
|
437 |
-
for cat, values in generic_cat.items():
|
438 |
-
fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
|
439 |
-
fig.update_layout(title="Similarity Scatter Plot", xaxis_title=prompt_x + " Similarity Score",
|
440 |
-
yaxis_title=prompt_y + ' Similarity score')
|
441 |
-
fig.update_traces(hoverinfo="none", hovertemplate=None)
|
442 |
data = {'df':df,'img_conf':img_conf,'img_paths':img_paths}
|
443 |
return [fig,None,Serverside(data)]
|
444 |
elif 'download-button' in changed_id:
|
@@ -447,12 +370,7 @@ def register_callbacks(app, df, clip, device, model, folders, img_conf, queries)
|
|
447 |
folders = images_by_label(img_paths)
|
448 |
img_conf = get_embeddings(preprocess,model,device,img_paths,session['session_id'])
|
449 |
generic_cat, df1 = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
|
450 |
-
fig =
|
451 |
-
for cat, values in generic_cat.items():
|
452 |
-
fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
|
453 |
-
fig.update_layout(title="Similarity Scatter Plot", xaxis_title=prompt_x + " Similarity Score",
|
454 |
-
yaxis_title=prompt_y + ' Similarity score')
|
455 |
-
fig.update_traces(hoverinfo="none", hovertemplate=None)
|
456 |
clean_up('image_data',session['session_id'])
|
457 |
data = {'df':df1,'img_conf':img_conf,'img_paths':img_paths}
|
458 |
return [fig,None,Serverside(data)]
|
|
|
10 |
import plotly.graph_objects as go
|
11 |
import pandas as pd
|
12 |
import base64
|
13 |
+
from dash_extensions.enrich import Output, Input, State,Serverside, html, dcc
|
|
|
14 |
import dash
|
15 |
import shutil
|
16 |
import json
|
|
|
17 |
|
18 |
def download_images(query, number_of_images, output_dir, num_threads, timeout=120):
|
19 |
"""
|
|
|
93 |
folders[folder_name].append(img_path)
|
94 |
return folders
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
def compute_embedding(preprocess,model,device,img):
|
97 |
"""
|
98 |
Compute the embedding for an image and return it
|
|
|
189 |
df = pd.DataFrame.from_dict(new_dict)
|
190 |
|
191 |
return generic_cat, df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
+
|
194 |
+
def show_scatter(prompt_x, prompt_y,generic_cat):
|
195 |
"""
|
196 |
Show a scatter plot of similarities for male, female, and generic categories.
|
197 |
Args:
|
|
|
205 |
fig = go.Figure()
|
206 |
for cat, values in generic_cat.items():
|
207 |
fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
|
208 |
+
fig.update_layout(title="Similarity Scatter Plot", xaxis_title=f"{prompt_x} Similarity Score", yaxis_title=f"{prompt_y} Similarity Score")
|
209 |
+
fig.update_traces(hoverinfo="none", hovertemplate=None)
|
|
|
|
|
210 |
return fig
|
211 |
|
212 |
def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
|
|
|
220 |
dash_html_components.Div: A Div containing inputs and the graph.
|
221 |
"""
|
222 |
num_options = [1, 5, 10, 20, 50]
|
|
|
223 |
return html.Div(
|
224 |
[
|
225 |
dcc.Store(id='memory-output', data={'df': df.to_json(),'img_conf':pd.DataFrame.from_dict(img_conf).to_json(),'img_paths':[i for i in img_paths]}),
|
|
|
268 |
style={'padding': '20px', 'border': '1px solid #e0e0e0', 'border-radius': '10px'}
|
269 |
)
|
270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
def clean_up(path,session_id):
|
272 |
data_path = os.path.join(path,session_id)
|
273 |
if os.path.exists(data_path):
|
|
|
361 |
img_conf = {k:[v['0'],v['1']] for k,v in img_conf.items()}
|
362 |
folders = images_by_label(img_paths)
|
363 |
generic_cat, df = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
|
364 |
+
fig = show_scatter(prompt_x, prompt_y,generic_cat)
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
data = {'df':df,'img_conf':img_conf,'img_paths':img_paths}
|
366 |
return [fig,None,Serverside(data)]
|
367 |
elif 'download-button' in changed_id:
|
|
|
370 |
folders = images_by_label(img_paths)
|
371 |
img_conf = get_embeddings(preprocess,model,device,img_paths,session['session_id'])
|
372 |
generic_cat, df1 = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
|
373 |
+
fig = show_scatter(prompt_x, prompt_y,generic_cat)
|
|
|
|
|
|
|
|
|
|
|
374 |
clean_up('image_data',session['session_id'])
|
375 |
data = {'df':df1,'img_conf':img_conf,'img_paths':img_paths}
|
376 |
return [fig,None,Serverside(data)]
|