Hafiz Imam commited on
Commit
0fc3879
1 Parent(s): 014ee4b
Files changed (2) hide show
  1. main.py +5 -5
  2. utils.py +7 -89
main.py CHANGED
@@ -20,15 +20,15 @@ from utils import preprocess
20
  img_paths = get_img_paths('image_data','')
21
  folders = images_by_label(img_paths)
22
  img_conf = get_embeddings(preprocess,model,device,img_paths,'')
23
- prompt = 'doctor'
24
- generic_cat,df = update_plot('CEO', prompt, clip, device, model, folders, img_conf, ['indian businessman','indian businesswoman'])
 
 
25
  fig = go.Figure()
26
- fig = show_scatter(prompt,generic_cat)
27
 
28
  fig.update_traces(hoverinfo="none", hovertemplate=None)
29
 
30
- prompt_x = 'CEO'
31
- prompt_y = 'doctor'
32
  def serve_layout():
33
  session_id = str(uuid.uuid4())
34
  return update_div(prompt_x,prompt_y,fig,df,img_conf,img_paths,session_id)
 
20
  img_paths = get_img_paths('image_data','')
21
  folders = images_by_label(img_paths)
22
  img_conf = get_embeddings(preprocess,model,device,img_paths,'')
23
+
24
+ prompt_x = 'CEO'
25
+ prompt_y = 'doctor'
26
+ generic_cat,df = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, ['indian businessman','indian businesswoman'])
27
  fig = go.Figure()
28
+ fig = show_scatter(prompt_x, prompt_y,generic_cat)
29
 
30
  fig.update_traces(hoverinfo="none", hovertemplate=None)
31
 
 
 
32
  def serve_layout():
33
  session_id = str(uuid.uuid4())
34
  return update_div(prompt_x,prompt_y,fig,df,img_conf,img_paths,session_id)
utils.py CHANGED
@@ -10,12 +10,10 @@ from tqdm import tqdm
10
  import plotly.graph_objects as go
11
  import pandas as pd
12
  import base64
13
- from dash_extensions.enrich import DashProxy, Output, Input, State,no_update, Serverside, html, dcc, \
14
- ServersideOutputTransform
15
  import dash
16
  import shutil
17
  import json
18
- import uuid
19
 
20
  def download_images(query, number_of_images, output_dir, num_threads, timeout=120):
21
  """
@@ -95,17 +93,6 @@ def images_by_label(img_paths):
95
  folders[folder_name].append(img_path)
96
  return folders
97
 
98
- def load_image(img: str):
99
- """
100
- Load an image and return it.
101
- """
102
- # if isinstance(img, str):
103
- # img = Image.open(img).convert('RGB')
104
- # else:
105
- # img = Image.open(io.BytesIO(img)).convert('RGB')
106
- transformed_img = transform_image(img)[:3].unsqueeze(0)
107
- return transformed_img
108
-
109
  def compute_embedding(preprocess,model,device,img):
110
  """
111
  Compute the embedding for an image and return it
@@ -202,31 +189,9 @@ def update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, quer
202
  df = pd.DataFrame.from_dict(new_dict)
203
 
204
  return generic_cat, df
205
-
206
- def show_histo(generic_cat):
207
- """
208
- Show a histogram of similarities for male, female, and generic categories.
209
- Args:
210
- generic_cat (dict): A dictionary containing similarities for generic categories.
211
- Returns:
212
- None
213
- """
214
- # Create a new figure
215
- fig = go.Figure()
216
-
217
- # Add histograms for generic categories
218
- for cat, values in generic_cat.items():
219
- fig.add_trace(go.Histogram(x=[v[0] for v in values], name=f'{cat}', histnorm='probability'))
220
-
221
- # Overlay histograms
222
- fig.update_layout(barmode='overlay')
223
- # Reduce opacity to see both histograms
224
- fig.update_traces(opacity=0.55)
225
-
226
- # Show the figure
227
- fig.show()
228
 
229
- def show_scatter(prompt, generic_cat):
 
230
  """
231
  Show a scatter plot of similarities for male, female, and generic categories.
232
  Args:
@@ -240,10 +205,8 @@ def show_scatter(prompt, generic_cat):
240
  fig = go.Figure()
241
  for cat, values in generic_cat.items():
242
  fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
243
-
244
- # Update layout
245
- fig.update_layout(title="Similarity Scatter Plot", xaxis_title="CEO Similarity Score", yaxis_title=prompt+' similarity score')
246
-
247
  return fig
248
 
249
  def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
@@ -257,7 +220,6 @@ def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
257
  dash_html_components.Div: A Div containing inputs and the graph.
258
  """
259
  num_options = [1, 5, 10, 20, 50]
260
- print(session_id)
261
  return html.Div(
262
  [
263
  dcc.Store(id='memory-output', data={'df': df.to_json(),'img_conf':pd.DataFrame.from_dict(img_conf).to_json(),'img_paths':[i for i in img_paths]}),
@@ -306,39 +268,6 @@ def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
306
  style={'padding': '20px', 'border': '1px solid #e0e0e0', 'border-radius': '10px'}
307
  )
308
 
309
-
310
-
311
- def display_images_with_preprocessing(img_paths, preprocess, device):
312
- """
313
- Display original and preprocessed images side by side.
314
- Args:
315
- img_paths (list): List of paths to image files.
316
- preprocess (callable): Function to preprocess the images.
317
- device: Device on which the preprocessing will be performed.
318
- Returns:
319
- None
320
- """
321
- # Process only the first 5 images
322
- for img_path in img_paths:
323
- # Preprocess the image
324
- x = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
325
- # Convert the tensor back to a numpy array
326
- preprocessed_image_np = x.squeeze(0).cpu().numpy().transpose((1, 2, 0))
327
-
328
- # Plot the original and preprocessed images side by side
329
- fig, axes = plt.subplots(1, 2, figsize=(10, 5))
330
- axes[0].imshow(Image.open(img_path))
331
- axes[0].set_title('Original Image')
332
- axes[0].axis('off')
333
-
334
- axes[1].imshow(preprocessed_image_np)
335
- axes[1].set_title('Preprocessed Image')
336
- axes[1].axis('off')
337
-
338
- plt.show()
339
- def download_images_parallel(args):
340
- query, n_images_per_class, data_path, num_threads = args
341
- download_images(query, n_images_per_class, data_path,num_threads)
342
  def clean_up(path,session_id):
343
  data_path = os.path.join(path,session_id)
344
  if os.path.exists(data_path):
@@ -432,13 +361,7 @@ def register_callbacks(app, df, clip, device, model, folders, img_conf, queries)
432
  img_conf = {k:[v['0'],v['1']] for k,v in img_conf.items()}
433
  folders = images_by_label(img_paths)
434
  generic_cat, df = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
435
-
436
- fig = go.Figure()
437
- for cat, values in generic_cat.items():
438
- fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
439
- fig.update_layout(title="Similarity Scatter Plot", xaxis_title=prompt_x + " Similarity Score",
440
- yaxis_title=prompt_y + ' Similarity score')
441
- fig.update_traces(hoverinfo="none", hovertemplate=None)
442
  data = {'df':df,'img_conf':img_conf,'img_paths':img_paths}
443
  return [fig,None,Serverside(data)]
444
  elif 'download-button' in changed_id:
@@ -447,12 +370,7 @@ def register_callbacks(app, df, clip, device, model, folders, img_conf, queries)
447
  folders = images_by_label(img_paths)
448
  img_conf = get_embeddings(preprocess,model,device,img_paths,session['session_id'])
449
  generic_cat, df1 = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
450
- fig = go.Figure()
451
- for cat, values in generic_cat.items():
452
- fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
453
- fig.update_layout(title="Similarity Scatter Plot", xaxis_title=prompt_x + " Similarity Score",
454
- yaxis_title=prompt_y + ' Similarity score')
455
- fig.update_traces(hoverinfo="none", hovertemplate=None)
456
  clean_up('image_data',session['session_id'])
457
  data = {'df':df1,'img_conf':img_conf,'img_paths':img_paths}
458
  return [fig,None,Serverside(data)]
 
10
  import plotly.graph_objects as go
11
  import pandas as pd
12
  import base64
13
+ from dash_extensions.enrich import Output, Input, State,Serverside, html, dcc
 
14
  import dash
15
  import shutil
16
  import json
 
17
 
18
  def download_images(query, number_of_images, output_dir, num_threads, timeout=120):
19
  """
 
93
  folders[folder_name].append(img_path)
94
  return folders
95
 
 
 
 
 
 
 
 
 
 
 
 
96
  def compute_embedding(preprocess,model,device,img):
97
  """
98
  Compute the embedding for an image and return it
 
189
  df = pd.DataFrame.from_dict(new_dict)
190
 
191
  return generic_cat, df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+
194
+ def show_scatter(prompt_x, prompt_y,generic_cat):
195
  """
196
  Show a scatter plot of similarities for male, female, and generic categories.
197
  Args:
 
205
  fig = go.Figure()
206
  for cat, values in generic_cat.items():
207
  fig.add_trace(go.Scatter(x=[v[0] for v in values], y=[v[1] for v in values], mode='markers', name=f'{cat}'))
208
+ fig.update_layout(title="Similarity Scatter Plot", xaxis_title=f"{prompt_x} Similarity Score", yaxis_title=f"{prompt_y} Similarity Score")
209
+ fig.update_traces(hoverinfo="none", hovertemplate=None)
 
 
210
  return fig
211
 
212
  def update_div(prompt_x, prompt_y, fig, df, img_conf, img_paths,session_id):
 
220
  dash_html_components.Div: A Div containing inputs and the graph.
221
  """
222
  num_options = [1, 5, 10, 20, 50]
 
223
  return html.Div(
224
  [
225
  dcc.Store(id='memory-output', data={'df': df.to_json(),'img_conf':pd.DataFrame.from_dict(img_conf).to_json(),'img_paths':[i for i in img_paths]}),
 
268
  style={'padding': '20px', 'border': '1px solid #e0e0e0', 'border-radius': '10px'}
269
  )
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  def clean_up(path,session_id):
272
  data_path = os.path.join(path,session_id)
273
  if os.path.exists(data_path):
 
361
  img_conf = {k:[v['0'],v['1']] for k,v in img_conf.items()}
362
  folders = images_by_label(img_paths)
363
  generic_cat, df = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
364
+ fig = show_scatter(prompt_x, prompt_y,generic_cat)
 
 
 
 
 
 
365
  data = {'df':df,'img_conf':img_conf,'img_paths':img_paths}
366
  return [fig,None,Serverside(data)]
367
  elif 'download-button' in changed_id:
 
370
  folders = images_by_label(img_paths)
371
  img_conf = get_embeddings(preprocess,model,device,img_paths,session['session_id'])
372
  generic_cat, df1 = update_plot(prompt_x, prompt_y, clip, device, model, folders, img_conf, [query1, query2])
373
+ fig = show_scatter(prompt_x, prompt_y,generic_cat)
 
 
 
 
 
374
  clean_up('image_data',session['session_id'])
375
  data = {'df':df1,'img_conf':img_conf,'img_paths':img_paths}
376
  return [fig,None,Serverside(data)]