vardaan123 commited on
Commit
85f8c1c
1 Parent(s): e75805e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +7 -13
model.py CHANGED
@@ -8,26 +8,21 @@ from torchvision.models import resnet18, resnet50
8
  from torchvision.models import ResNet18_Weights, ResNet50_Weights
9
 
10
  class DistMult(nn.Module):
11
- def __init__(self, args, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
12
  super(DistMult, self).__init__()
13
- self.args = args
14
  self.num_ent_uid = num_ent_uid
15
 
16
  self.num_relations = 4
17
 
18
- self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, args.embedding_dim, sparse=False)
19
- self.rel_embedding = torch.nn.Embedding(self.num_relations, args.embedding_dim, sparse=False)
20
 
21
- self.location_embedding = MLP(args.location_input_dim, args.embedding_dim, args.mlp_location_numlayer)
22
 
23
- self.time_embedding = MLP(args.time_input_dim, args.embedding_dim, args.mlp_time_numlayer)
24
 
25
- if self.args.img_embed_model == 'resnet50':
26
- self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
27
- self.image_embedding.fc = nn.Linear(2048, args.embedding_dim)
28
- else:
29
- self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
30
- self.image_embedding.fc = nn.Linear(512, args.embedding_dim)
31
 
32
  self.target_list = target_list
33
 
@@ -36,7 +31,6 @@ class DistMult(nn.Module):
36
  if all_timestamps is not None:
37
  self.all_timestamps = all_timestamps.to(device)
38
 
39
- self.args = args
40
  self.device = device
41
 
42
  self.init()
 
8
  from torchvision.models import ResNet18_Weights, ResNet50_Weights
9
 
10
  class DistMult(nn.Module):
11
+ def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
12
  super(DistMult, self).__init__()
 
13
  self.num_ent_uid = num_ent_uid
14
 
15
  self.num_relations = 4
16
 
17
+ self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, 512, sparse=False)
18
+ self.rel_embedding = torch.nn.Embedding(self.num_relations, 512, sparse=False)
19
 
20
+ self.location_embedding = MLP(2, 512, 3)
21
 
22
+ self.time_embedding = MLP(1, 512, 3)
23
 
24
+ self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
25
+ self.image_embedding.fc = nn.Linear(2048, 512)
 
 
 
 
26
 
27
  self.target_list = target_list
28
 
 
31
  if all_timestamps is not None:
32
  self.all_timestamps = all_timestamps.to(device)
33
 
 
34
  self.device = device
35
 
36
  self.init()