COSMO / model.py
vardaan123's picture
Upload folder using huggingface_hub
3dba732 verified
raw
history blame
No virus
3.97 kB
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor
from typing import Tuple
from torchvision.models import resnet18, resnet50
from torchvision.models import ResNet18_Weights, ResNet50_Weights
class DistMult(nn.Module):
    """DistMult-style scorer over a multimodal knowledge graph.

    A head (entity id, GPS coordinates, time features, or an image) is
    embedded into a shared 512-d space, multiplied element-wise with a
    relation embedding, and scored against every candidate tail with a dot
    product (see ``forward_ce``).

    Args:
        num_ent_uid: Number of rows in the entity embedding table.
        target_list: Tensor of indices into the entity table, used as the
            candidate tails for ``('image', 'id')`` triples.
        device: Device that ``all_locs`` / ``all_timestamps`` are moved to.
        all_locs: Optional tensor of candidate locations, embedded with
            ``location_embedding`` for ``('image', 'location')`` triples.
        num_habitat: Unused by this class; kept for interface compatibility.
        all_timestamps: Optional tensor of candidate timestamps, embedded
            with ``time_embedding`` for ``('image', 'time')`` triples.
    """

    def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
        super().__init__()
        self.num_ent_uid = num_ent_uid
        self.num_relations = 4
        self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, 512, sparse=False)
        self.rel_embedding = torch.nn.Embedding(self.num_relations, 512, sparse=False)
        self.location_embedding = MLP(2, 512, 3)
        self.time_embedding = MLP(1, 512, 3)
        # ImageNet-pretrained ResNet-50 whose classifier head is replaced by a
        # 2048 -> 512 projection into the shared embedding space.
        self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.image_embedding.fc = nn.Linear(2048, 512)
        self.target_list = target_list
        # Always define these attributes so forward_ce can report a clear
        # error when a location/time query is made without candidates.
        self.all_locs = all_locs.to(device) if all_locs is not None else None
        self.all_timestamps = all_timestamps.to(device) if all_timestamps is not None else None
        self.device = device
        self.init()

    def init(self):
        """Xavier-uniform initialise the entity/relation tables and the image projection.

        The location/time MLPs and the pretrained ResNet backbone are not
        touched here.
        """
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)
        nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)

    def forward_ce(self, h, r, triple_type=None):
        """Score each (head, relation) pair against every candidate tail.

        Args:
            h: Batch of heads; its shape selects the modality
                (see ``batch_embedding_concat_h``).
            r: Relation indices; the last dimension is squeezed before lookup.
            triple_type: ``(head_type, tail_type)`` tuple selecting the
                candidate tail set: ``('image', 'id')``, ``('id', 'id')``,
                ``('image', 'location')`` or ``('image', 'time')``.

        Returns:
            Score matrix of shape ``[batch, num_candidates]``.

        Raises:
            NotImplementedError: If ``triple_type`` is not supported.
            ValueError: If the candidate tensor required by ``triple_type``
                (``all_locs`` / ``all_timestamps``) was not provided.
        """
        emb_h = self.batch_embedding_concat_h(h)  # [batch, hid]
        emb_r = self.rel_embedding(r.squeeze(-1))  # [batch, hid]
        emb_hr = emb_h * emb_r  # [batch, hid]
        if triple_type == ('image', 'id'):
            tails = self.ent_embedding.weight[self.target_list.squeeze(-1)]
        elif triple_type == ('id', 'id'):
            tails = self.ent_embedding.weight
        elif triple_type == ('image', 'location'):
            if self.all_locs is None:
                raise ValueError("triple_type ('image', 'location') requires all_locs")
            # Recomputed on every call so gradients flow into location_embedding.
            tails = self.location_embedding(self.all_locs)
        elif triple_type == ('image', 'time'):
            if self.all_timestamps is None:
                raise ValueError("triple_type ('image', 'time') requires all_timestamps")
            tails = self.time_embedding(self.all_timestamps)
        else:
            raise NotImplementedError(f"unsupported triple_type: {triple_type!r}")
        return torch.mm(emb_hr, tails.T)  # [batch, n_candidates]

    def batch_embedding_concat_h(self, e1):
        """Embed a batch of heads, dispatching on the tensor's shape.

        * 1-D, or second dim 1: entity ids -> ``ent_embedding``.
        * second dim 15: time features -> ``time_embedding``.
        * second dim 2: GPS coordinates -> ``location_embedding``.
        * second dim 3: images -> ``image_embedding`` (presumably
          ``[batch, 3, H, W]`` RGB tensors — confirm against the data loader).

        Raises:
            ValueError: If the shape matches none of the cases above.
        """
        if e1.dim() == 1 or e1.size(1) == 1:  # uid
            return self.ent_embedding(e1.squeeze(-1))
        if e1.size(1) == 15:  # time
            # NOTE(review): time_embedding is MLP(1, ...), i.e. it expects a
            # single input feature, so a 15-wide tensor would fail inside the
            # MLP (and a 1-wide one is caught by the uid branch above).
            # Confirm the intended time encoding.
            return self.time_embedding(e1)
        if e1.size(1) == 2:  # GPS
            return self.location_embedding(e1)
        if e1.size(1) == 3:  # Image
            return self.image_embedding(e1)
        raise ValueError(
            f"cannot infer head modality from tensor of shape {tuple(e1.shape)}"
        )
class MLP(nn.Module):
    """Stack of ``Linear -> Dropout -> PReLU`` layers stepping from input_dim to output_dim.

    Layer widths are ``output_dim + i * ((input_dim - output_dim) // num_layers)``
    for ``i = num_layers - 1, ..., 0``, so the last layer always outputs
    exactly ``output_dim`` features.

    Args:
        input_dim: Number of input features (last dimension of ``x``).
        output_dim: Number of output features.
        num_layers: Number of Linear/Dropout/PReLU stages; must be >= 1.
        p_dropout: Dropout probability applied after each Linear layer.
        bias: Whether each Linear layer has a bias term.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 num_layers=3,
                 p_dropout=0.0,
                 bias=True):
        super().__init__()
        if num_layers < 1:
            # 0 used to raise an opaque ZeroDivisionError and negative values
            # silently produced an empty (identity) network.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.p_dropout = p_dropout
        step_size = (input_dim - output_dim) // num_layers
        hidden_dims = [output_dim + (i * step_size)
                       for i in reversed(range(num_layers))]
        layers = []
        layer_indim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([nn.Linear(layer_indim, hidden_dim, bias=bias),
                           nn.Dropout(p=self.p_dropout, inplace=True),
                           nn.PReLU()])
            layer_indim = hidden_dim
        self.mlp = nn.Sequential(*layers)
        # initialize weights
        self.init()

    def forward(self, x):
        """Apply the layer stack; ``x``'s last dimension must equal ``input_dim``."""
        return self.mlp(x)

    def init(self):
        """Re-initialise every parameter (weights, biases, PReLU slopes) from U(0, 1).

        NOTE(review): an all-positive init is unusual for linear layers;
        presumably intentional — confirm before changing.
        """
        for param in self.parameters():
            nn.init.uniform_(param)