# LUAR-MUD / model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from transformers import AutoModel, PreTrainedModel

from .config import LUARConfig


class SelfAttention(nn.Module):
    """Implements Dot-Product Self-Attention as used in "Attention Is All You Need".
    """

    def __init__(self):
        super(SelfAttention, self).__init__()
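
    # Scaled dot-product attention: softmax(k @ q^T / sqrt(d_k)) @ v. Within this
    # file it is only called with k == q == v, the per-episode comment embeddings
    # of shape (batch_size, num_comments, hidden_size).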
    def forward(self, k, q, v):
        d_k = q.size(-1)
        scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)


class LUAR(PreTrainedModel):
    """Defines the LUAR model.
    """
    config_class = LUARConfig

    def __init__(self, config):
        super().__init__(config)
        self.create_transformer()
        self.attn_fn = SelfAttention()
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

    def create_transformer(self):
        """Creates the Transformer backbone.
        """
        self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
        self.hidden_size = self.transformer.config.hidden_size
        self.num_attention_heads = self.transformer.config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean Pooling as described in the SBERT paper.
        """
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        return sum_embeddings / sum_mask

    def get_episode_embeddings(self, input_ids, attention_mask):
        """Computes the Author Embedding.
        """
        B, E, _ = attention_mask.shape
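
        # B: episodes in the batch, E: comments per episode; input_ids and
        # attention_mask arrive as (B, E, L) and are flattened to (B*E, L) so
        # that each comment is encoded by the backbone independently.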
        input_ids = rearrange(input_ids, 'b e l -> (b e) l')
        attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')

        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=True
        )

        # at this point, we're embedding individual "comments"
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)

        # aggregate individual comment embeddings into episode embeddings
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
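        # max-pool element-wise over the E comment positions, leaving a single
        # hidden_size-dimensional vector per episode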
        episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')
        episode_embeddings = self.linear(episode_embeddings)

        return episode_embeddings

    def forward(self, input_ids, attention_mask):
        """Calculates a fixed-length feature vector for a batch of episode samples.
        """
        output = self.get_episode_embeddings(input_ids, attention_mask)
        return output
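

# Minimal usage sketch (not part of the original upload). It assumes that
# LUARConfig accepts an `embedding_size` keyword and that this file is run as a
# module (e.g. `python -m <package>.model`) so the relative import above resolves;
# both are assumptions, not guarantees from the original code.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = LUARConfig(embedding_size=512)  # assumed constructor signature
    # Built from scratch here (random projection weights); a trained checkpoint
    # would normally be loaded with .from_pretrained(...) instead.
    model = LUAR(config).eval()
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")

    # Two episodes (authors) with three comments each, padded to 32 tokens: (B=2, E=3, L=32).
    episodes = [
        ["first comment", "second comment", "third comment"],
        ["another author", "writes in", "a different style"],
    ]
    flat = [text for episode in episodes for text in episode]
    tokens = tokenizer(flat, padding="max_length", truncation=True, max_length=32, return_tensors="pt")

    input_ids = rearrange(tokens["input_ids"], "(b e) l -> b e l", b=2, e=3)
    attention_mask = rearrange(tokens["attention_mask"], "(b e) l -> b e l", b=2, e=3)

    with torch.no_grad():
        author_embeddings = model(input_ids, attention_mask)
    print(author_embeddings.shape)  # (2, 512) under the assumed config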