Spaces:
Sleeping
Sleeping
peeyushsinghal
commited on
Commit
•
0818965
1
Parent(s):
a7c086c
changed file names
Browse files
app.py
CHANGED
@@ -29,7 +29,7 @@ import base64
|
|
29 |
# else:
|
30 |
# print("Google Drive is already mounted.")
|
31 |
|
32 |
-
list_c1 = torch.load('
|
33 |
|
34 |
class CustomDataset(torch.utils.data.Dataset):
|
35 |
def __init__(self, data):
|
@@ -54,7 +54,7 @@ def get_images():
|
|
54 |
pil_images = [transform_to_pil(image) for image in images]
|
55 |
return pil_images, labels.tolist()
|
56 |
|
57 |
-
list_c2 = torch.load('
|
58 |
dataset_c2 = CustomDataset(list_c2)
|
59 |
dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
|
60 |
def get_images_2():
|
@@ -173,7 +173,7 @@ class Network(nn.Module):
|
|
173 |
loaded_model_non_dann = Network()
|
174 |
loaded_model_non_dann = loaded_model_non_dann.to(device)
|
175 |
# Load the saved state dictionary
|
176 |
-
loaded_model_non_dann.load_state_dict(torch.load('
|
177 |
loaded_model_non_dann.eval()
|
178 |
|
179 |
## DANN
|
@@ -181,7 +181,7 @@ loaded_model_non_dann.eval()
|
|
181 |
loaded_model_dann = Network()
|
182 |
loaded_model_dann = loaded_model_dann.to(device)
|
183 |
# Load the saved state dictionary
|
184 |
-
loaded_model_dann.load_state_dict(torch.load('
|
185 |
loaded_model_dann.eval()
|
186 |
|
187 |
img_size = 28 # for mnist
|
|
|
29 |
# else:
|
30 |
# print("Google Drive is already mounted.")
|
31 |
|
32 |
+
list_c1 = torch.load('list_mnist_m_non_dann_misclassified_dann_classified_08_07.pt')
|
33 |
|
34 |
class CustomDataset(torch.utils.data.Dataset):
|
35 |
def __init__(self, data):
|
|
|
54 |
pil_images = [transform_to_pil(image) for image in images]
|
55 |
return pil_images, labels.tolist()
|
56 |
|
57 |
+
list_c2 = torch.load('list_mnist_m_non_dann_misclassified_dann_misclassified_08_07.pt')
|
58 |
dataset_c2 = CustomDataset(list_c2)
|
59 |
dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
|
60 |
def get_images_2():
|
|
|
173 |
loaded_model_non_dann = Network()
|
174 |
loaded_model_non_dann = loaded_model_non_dann.to(device)
|
175 |
# Load the saved state dictionary
|
176 |
+
loaded_model_non_dann.load_state_dict(torch.load('non_dann_08_07.pt', map_location=device), strict=False)
|
177 |
loaded_model_non_dann.eval()
|
178 |
|
179 |
## DANN
|
|
|
181 |
loaded_model_dann = Network()
|
182 |
loaded_model_dann = loaded_model_dann.to(device)
|
183 |
# Load the saved state dictionary
|
184 |
+
loaded_model_dann.load_state_dict(torch.load('dann_08_07.pt', map_location=device), strict=False)
|
185 |
loaded_model_dann.eval()
|
186 |
|
187 |
img_size = 28 # for mnist
|