peeyushsinghal commited on
Commit
0818965
1 Parent(s): a7c086c

changed file names

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -29,7 +29,7 @@ import base64
29
  # else:
30
  # print("Google Drive is already mounted.")
31
 
32
- list_c1 = torch.load('list_mnist_m_non_dann_misclassified_dann_classified.pt')
33
 
34
  class CustomDataset(torch.utils.data.Dataset):
35
  def __init__(self, data):
@@ -54,7 +54,7 @@ def get_images():
54
  pil_images = [transform_to_pil(image) for image in images]
55
  return pil_images, labels.tolist()
56
 
57
- list_c2 = torch.load('list_mnist_m_non_dann_misclassified_dann_misclassified.pt')
58
  dataset_c2 = CustomDataset(list_c2)
59
  dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
60
  def get_images_2():
@@ -173,7 +173,7 @@ class Network(nn.Module):
173
  loaded_model_non_dann = Network()
174
  loaded_model_non_dann = loaded_model_non_dann.to(device)
175
  # Load the saved state dictionary
176
- loaded_model_non_dann.load_state_dict(torch.load('non_dann_26_06.pt', map_location=device), strict=False)
177
  loaded_model_non_dann.eval()
178
 
179
  ## DANN
@@ -181,7 +181,7 @@ loaded_model_non_dann.eval()
181
  loaded_model_dann = Network()
182
  loaded_model_dann = loaded_model_dann.to(device)
183
  # Load the saved state dictionary
184
- loaded_model_dann.load_state_dict(torch.load('dann_26_06.pt', map_location=device), strict=False)
185
  loaded_model_dann.eval()
186
 
187
  img_size = 28 # for mnist
 
29
  # else:
30
  # print("Google Drive is already mounted.")
31
 
32
+ list_c1 = torch.load('list_mnist_m_non_dann_misclassified_dann_classified_08_07.pt')
33
 
34
  class CustomDataset(torch.utils.data.Dataset):
35
  def __init__(self, data):
 
54
  pil_images = [transform_to_pil(image) for image in images]
55
  return pil_images, labels.tolist()
56
 
57
+ list_c2 = torch.load('list_mnist_m_non_dann_misclassified_dann_misclassified_08_07.pt')
58
  dataset_c2 = CustomDataset(list_c2)
59
  dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
60
  def get_images_2():
 
173
  loaded_model_non_dann = Network()
174
  loaded_model_non_dann = loaded_model_non_dann.to(device)
175
  # Load the saved state dictionary
176
+ loaded_model_non_dann.load_state_dict(torch.load('non_dann_08_07.pt', map_location=device), strict=False)
177
  loaded_model_non_dann.eval()
178
 
179
  ## DANN
 
181
  loaded_model_dann = Network()
182
  loaded_model_dann = loaded_model_dann.to(device)
183
  # Load the saved state dictionary
184
+ loaded_model_dann.load_state_dict(torch.load('dann_08_07.pt', map_location=device), strict=False)
185
  loaded_model_dann.eval()
186
 
187
  img_size = 28 # for mnist