ludusc commited on
Commit
3b15df8
1 Parent(s): 1a39dfb

added cv2 and show sample image

Browse files
backend/disentangle_concepts.py CHANGED
@@ -45,3 +45,13 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
45
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
46
  images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
47
  return images
 
 
 
 
 
 
 
 
 
 
 
45
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
46
  images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
47
  return images
48
+
49
+ def generate_original_image(z, model):
50
+ device = torch.device('cpu')
51
+ G = model.to(device) # type: ignore
52
+ # Labels.
53
+ label = torch.zeros([1, G.c_dim], device=device)
54
+ z = torch.from_numpy(z.copy()).to(device)
55
+ img = G(z, label, truncation_psi=0.7, noise_mode='random')
56
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
57
+ return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
pages/1_Disentanglement.py CHANGED
@@ -128,11 +128,18 @@ with input_col_3:
128
  epsilon_button = st.form_submit_button('Choose the defined epsilon')
129
 
130
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
131
- # original_image_dict = load_image(st.session_state.image_id)
 
 
 
132
  # input_image = original_image_dict['image']
133
  # input_label = original_image_dict['label']
134
  # input_id = original_image_dict['id']
135
 
 
 
 
 
136
 
137
 
138
  # if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
 
128
  epsilon_button = st.form_submit_button('Choose the defined epsilon')
129
 
130
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
131
+
132
+ model = torch.load('./data/model_files/pytorch_model.bin')
133
+ original_image_vec = annotations['z_vectors'][st.session_state.image_id]
134
+ img = generate_original_image(original_image_vec, model)
135
  # input_image = original_image_dict['image']
136
  # input_label = original_image_dict['label']
137
  # input_id = original_image_dict['id']
138
 
139
+ with smoothgrad_col_3:
140
+ st.image(img)
141
+ header_col_1.write(f'Base image')
142
+
143
 
144
 
145
  # if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
requirements.txt CHANGED
@@ -12,4 +12,5 @@ torchvision==0.11.2
12
  tqdm==4.64.1
13
  transformers==4.25.1
14
  scikit-learn
15
- altair==4.0
 
 
12
  tqdm==4.64.1
13
  transformers==4.25.1
14
  scikit-learn
15
+ altair==4.0
16
+ opencv-python