Spaces:
Sleeping
Sleeping
vardaan123
commited on
Commit
•
85f8c1c
1
Parent(s):
e75805e
Update model.py
Browse files
model.py
CHANGED
@@ -8,26 +8,21 @@ from torchvision.models import resnet18, resnet50
|
|
8 |
from torchvision.models import ResNet18_Weights, ResNet50_Weights
|
9 |
|
10 |
class DistMult(nn.Module):
|
11 |
-
def __init__(self,
|
12 |
super(DistMult, self).__init__()
|
13 |
-
self.args = args
|
14 |
self.num_ent_uid = num_ent_uid
|
15 |
|
16 |
self.num_relations = 4
|
17 |
|
18 |
-
self.ent_embedding = torch.nn.Embedding(self.num_ent_uid,
|
19 |
-
self.rel_embedding = torch.nn.Embedding(self.num_relations,
|
20 |
|
21 |
-
self.location_embedding = MLP(
|
22 |
|
23 |
-
self.time_embedding = MLP(
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
self.image_embedding.fc = nn.Linear(2048, args.embedding_dim)
|
28 |
-
else:
|
29 |
-
self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
|
30 |
-
self.image_embedding.fc = nn.Linear(512, args.embedding_dim)
|
31 |
|
32 |
self.target_list = target_list
|
33 |
|
@@ -36,7 +31,6 @@ class DistMult(nn.Module):
|
|
36 |
if all_timestamps is not None:
|
37 |
self.all_timestamps = all_timestamps.to(device)
|
38 |
|
39 |
-
self.args = args
|
40 |
self.device = device
|
41 |
|
42 |
self.init()
|
|
|
8 |
from torchvision.models import ResNet18_Weights, ResNet50_Weights
|
9 |
|
10 |
class DistMult(nn.Module):
|
11 |
+
def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
|
12 |
super(DistMult, self).__init__()
|
|
|
13 |
self.num_ent_uid = num_ent_uid
|
14 |
|
15 |
self.num_relations = 4
|
16 |
|
17 |
+
self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, 512, sparse=False)
|
18 |
+
self.rel_embedding = torch.nn.Embedding(self.num_relations, 512, sparse=False)
|
19 |
|
20 |
+
self.location_embedding = MLP(2, 512, 3)
|
21 |
|
22 |
+
self.time_embedding = MLP(1, 512, 3)
|
23 |
|
24 |
+
self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
25 |
+
self.image_embedding.fc = nn.Linear(2048, 512)
|
|
|
|
|
|
|
|
|
26 |
|
27 |
self.target_list = target_list
|
28 |
|
|
|
31 |
if all_timestamps is not None:
|
32 |
self.all_timestamps = all_timestamps.to(device)
|
33 |
|
|
|
34 |
self.device = device
|
35 |
|
36 |
self.init()
|