import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from transformers import AutoModel, PreTrainedModel

from .config import LUARConfig

# Backbone checkpoint used when the config does not name one explicitly.
DEFAULT_TRANSFORMER_NAME = "sentence-transformers/paraphrase-distilroberta-base-v1"


class SelfAttention(nn.Module):
    """Implements Dot-Product Self-Attention as used in "Attention is all You Need".

    The module has no learnable parameters.
    """

    def forward(self, k: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Compute ``softmax(q @ k^T / sqrt(d)) @ v``.

        The ``(k, q, v)`` argument order is kept for backward compatibility.

        Args:
            k: keys, shape ``(..., n_k, d)``.
            q: queries, shape ``(..., n_q, d)``.
            v: values, shape ``(..., n_k, d_v)``.

        Returns:
            Tensor of shape ``(..., n_q, d_v)``.
        """
        d_k = q.size(-1)
        # Rows index queries and columns index keys, so the softmax below
        # normalises each query's weights over the keys.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)


class LUAR(PreTrainedModel):
    """Defines the LUAR model.

    Each input sample is an episode of ``E`` token sequences ("comments").
    Every comment is encoded by a Transformer backbone and mean-pooled. The
    comment embeddings of an episode are mixed with self-attention,
    max-pooled over the episode, and linearly projected to
    ``config.embedding_size``.
    """

    config_class = LUARConfig

    def __init__(self, config):
        super().__init__(config)
        self.create_transformer()
        self.attn_fn = SelfAttention()
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

    def create_transformer(self):
        """Creates the Transformer backbone and records its dimensions.

        The checkpoint comes from ``config.transformer_name`` when the config
        defines it; otherwise ``DEFAULT_TRANSFORMER_NAME`` is used.
        """
        name = getattr(self.config, "transformer_name", DEFAULT_TRANSFORMER_NAME)
        self.transformer = AutoModel.from_pretrained(name)
        self.hidden_size = self.transformer.config.hidden_size
        self.num_attention_heads = self.transformer.config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Mean Pooling as described in the SBERT paper.

        Args:
            token_embeddings: ``(b, l, d)`` per-token states.
            attention_mask: ``(b, l)`` mask, nonzero for tokens to include.

        Returns:
            ``(b, d)`` masked average of the token states, in the dtype of
            ``token_embeddings``.
        """
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        # The clamp avoids division by zero for rows whose mask is all zeros.
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        # Accumulate in float32, then cast back to the backbone's dtype.
        # Otherwise a half-precision model would feed float32 into its
        # fp16 projection head.
        return (sum_embeddings / sum_mask).to(token_embeddings.dtype)

    def get_episode_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Computes the Author Embedding.

        Args:
            input_ids: ``(B, E, L)`` token ids, i.e. ``E`` comments of ``L``
                tokens per episode.
            attention_mask: ``(B, E, L)`` mask matching ``input_ids``.

        Returns:
            ``(B, config.embedding_size)`` episode embeddings.
        """
        B, E, _ = attention_mask.shape

        # Flatten episodes so the backbone sees one comment per row.
        input_ids = rearrange(input_ids, 'b e l -> (b e) l')
        attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')

        # Only the last layer is used, so intermediate hidden states are not
        # requested.
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # at this point, we're embedding individual "comments"
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b e) d -> b e d', b=B, e=E)

        # aggregate individual comments embeddings into episode embeddings
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e d -> b d', 'max')
        return self.linear(episode_embeddings)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Calculates a fixed-length feature vector for a batch of episode samples.
        """
        return self.get_episode_embeddings(input_ids, attention_mask)