"""DistMult scorer over multimodal knowledge-graph triples (entity id, image, GPS, time)."""
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
class DistMult(nn.Module):
    """DistMult-style scorer whose head embeddings may come from entity ids, images, GPS, or time."""

    def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
        super().__init__()
        self.num_ent_uid = num_ent_uid
        self.num_relations = 4
        # 512-dim lookup embeddings for entities and the four relation types
        self.ent_embedding = nn.Embedding(self.num_ent_uid, 512, sparse=False)
        self.rel_embedding = nn.Embedding(self.num_relations, 512, sparse=False)
        # Modality encoders: GPS (lat, lon) and timestamps go through small MLPs;
        # images through a ResNet-50 whose classifier head is replaced by a 512-dim projection.
        self.location_embedding = MLP(2, 512, 3)
        self.time_embedding = MLP(1, 512, 3)
        self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.image_embedding.fc = nn.Linear(2048, 512)
        self.target_list = target_list  # candidate entity ids for ('image', 'id') scoring
        if all_locs is not None:
            self.all_locs = all_locs.to(device)
        if all_timestamps is not None:
            self.all_timestamps = all_timestamps.to(device)
        self.device = device
        self.init()
    def init(self):
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)
        nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)
    def forward_ce(self, h, r, triple_type=None):
        """Score (h, r, ?) against all candidate tails of the given triple type."""
        emb_h = self.batch_embedding_concat_h(h)   # [batch, hid]
        emb_r = self.rel_embedding(r.squeeze(-1))  # [batch, hid]
        emb_hr = emb_h * emb_r                     # [batch, hid]
        if triple_type == ('image', 'id'):
            # restrict candidates to the target entity list
            score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T)  # [batch, n_targets]
        elif triple_type == ('id', 'id'):
            score = torch.mm(emb_hr, self.ent_embedding.weight.T)  # [batch, n_ent]
        elif triple_type == ('image', 'location'):
            loc_emb = self.location_embedding(self.all_locs)  # location embeddings recomputed on every call
            score = torch.mm(emb_hr, loc_emb.T)               # [batch, n_locs]
        elif triple_type == ('image', 'time'):
            time_emb = self.time_embedding(self.all_timestamps)
            score = torch.mm(emb_hr, time_emb.T)              # [batch, n_timestamps]
        else:
            raise NotImplementedError(f'unsupported triple_type: {triple_type}')
        return score
    def batch_embedding_concat_h(self, e1):
        """Embed the head entity, dispatching on the shape of e1 to pick the modality encoder."""
        if len(e1.size()) == 1 or e1.size(1) == 1:  # entity uid
            return self.ent_embedding(e1.squeeze(-1))
        elif e1.size(1) == 15:  # time features (note: self.time_embedding is built with input_dim=1)
            return self.time_embedding(e1)
        elif e1.size(1) == 2:   # GPS (lat, lon)
            return self.location_embedding(e1)
        elif e1.size(1) == 3:   # image [batch, 3, H, W]
            return self.image_embedding(e1)
        raise ValueError(f'cannot infer modality from input of shape {tuple(e1.size())}')
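# --- Illustration (not part of the original model) --------------------------
# A minimal sketch of the DistMult scoring rule that forward_ce applies above:
# score(h, r, t) = <e_h * e_r, e_t>, evaluated against every candidate tail
# with a single matrix multiply. All tensors here are random stand-ins; the
# 512-dim size simply mirrors the embeddings used in this file.
def _distmult_scoring_sketch():
    batch, n_candidates, hid = 4, 10, 512
    emb_h = torch.randn(batch, hid)              # head embeddings
    emb_r = torch.randn(batch, hid)              # relation embeddings
    candidates = torch.randn(n_candidates, hid)  # candidate tail embeddings
    return torch.mm(emb_h * emb_r, candidates.T)  # [batch, n_candidates]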
class MLP(nn.Module):
    """Simple MLP whose layer widths step linearly from input_dim toward output_dim."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 num_layers=3,
                 p_dropout=0.0,
                 bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.p_dropout = p_dropout
        # evenly spaced hidden widths interpolating between input_dim and output_dim
        step_size = (input_dim - output_dim) // num_layers
        hidden_dims = [output_dim + (i * step_size)
                       for i in reversed(range(num_layers))]
        mlp = []
        layer_indim = input_dim
        for hidden_dim in hidden_dims:
            mlp.extend([nn.Linear(layer_indim, hidden_dim, bias),
                        nn.Dropout(p=self.p_dropout, inplace=True),
                        nn.PReLU()])
            layer_indim = hidden_dim
        self.mlp = nn.Sequential(*mlp)
        # initialize weights
        self.init()

    def forward(self, x):
        return self.mlp(x)

    def init(self):
        # uniform(0, 1) init over all parameters, including biases and PReLU slopes
        for param in self.parameters():
            nn.init.uniform_(param)
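# --- Usage sketch (assumptions, not from the original repo) ------------------
# A minimal smoke test of forward_ce on ('id', 'id') triples. The entity count,
# batch size, and dummy locations/timestamps are illustrative values; note that
# constructing DistMult downloads pretrained ImageNet ResNet-50 weights.
if __name__ == '__main__':
    device = 'cpu'
    num_ent = 10
    target_list = torch.arange(num_ent).unsqueeze(-1)  # hypothetical candidate ids
    model = DistMult(num_ent, target_list, device,
                     all_locs=torch.rand(5, 2),        # dummy (lat, lon) pairs
                     all_timestamps=torch.rand(7, 1))  # dummy timestamps
    h = torch.randint(0, num_ent, (4, 1))              # head entity uids
    r = torch.randint(0, model.num_relations, (4, 1))  # relation ids
    scores = model.forward_ce(h, r, triple_type=('id', 'id'))
    print(scores.shape)  # expected: torch.Size([4, 10])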