|
|
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange, reduce, repeat |
|
from transformers import AutoModel, PreTrainedModel |
|
|
|
from .config import LUARConfig |
|
|
|
class SelfAttention(nn.Module):
    """Parameter-free scaled dot-product attention.

    Follows the attention formulation from "Attention is all You Need".
    The module holds no learned weights; it only combines its three inputs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, k, q, v):
        """Weight ``v`` by the softmax-normalised similarity of ``k`` and ``q``.

        Note the argument order (``k`` first, then ``q``): the score matrix is
        ``k @ q^T``, divided by ``sqrt(q.size(-1))``, and the softmax runs over
        its last axis before the weights are applied to ``v``.
        """
        scale = math.sqrt(q.size(-1))
        logits = torch.matmul(k, q.transpose(-2, -1)) / scale
        weights = F.softmax(logits, dim=-1)
        attended = torch.matmul(weights, v)
        return attended
|
|
|
class LUAR(PreTrainedModel):
    """Defines the LUAR model.

    Inputs are rank-3 tensors laid out as (batch, episode, length). Every
    sequence is encoded by a Transformer backbone and mean-pooled, the pooled
    vectors within an episode attend to one another, and the result is
    max-pooled over the episode axis and linearly projected to
    ``config.embedding_size`` features.
    """

    config_class = LUARConfig

    def __init__(self, config):
        """Build the backbone, the attention module and the output projection.

        Args:
            config: Model configuration; ``config.embedding_size`` sets the
                output width of the final linear layer.
        """
        super().__init__(config)
        # Must run before the linear layer is built: it defines self.hidden_size.
        self.create_transformer()
        self.attn_fn = SelfAttention()
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

    def create_transformer(self):
        """Creates the Transformer backbone and records its dimensions.

        The backbone is loaded with ``AutoModel.from_pretrained`` from the
        checkpoint identifier hard-coded below. Its hidden size, number of
        attention heads, and the per-head width derived from the two are
        stored on ``self``.
        """
        self.transformer = AutoModel.from_pretrained(
            "sentence-transformers/paraphrase-distilroberta-base-v1"
        )
        backbone_config = self.transformer.config
        self.hidden_size = backbone_config.hidden_size
        self.num_attention_heads = backbone_config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean Pooling as described in the SBERT paper.

        Args:
            token_embeddings: Per-token vectors, (batch, length, hidden_size).
            attention_mask: Rank-2 mask, (batch, length); it is cast to float
                and used as a multiplicative weight for each token.

        Returns:
            Tensor of shape (batch, hidden_size): the mask-weighted sum of the
            token vectors divided by the mask total along the length axis.
        """
        input_mask_expanded = repeat(
            attention_mask, 'b l -> b l d', d=self.hidden_size
        ).float()
        sum_embeddings = reduce(
            token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum'
        )
        # The clamp keeps the division finite when a mask row sums to zero.
        sum_mask = torch.clamp(
            reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9
        )
        return sum_embeddings / sum_mask

    def get_episode_embeddings(self, input_ids, attention_mask):
        """Computes the Author Embedding.

        Args:
            input_ids: Token ids, rank 3: (batch, episode, length).
            attention_mask: Mask with the same layout as ``input_ids``.

        Returns:
            Tensor of shape (batch, config.embedding_size).
        """
        B, E, _ = attention_mask.shape

        # Fold the episode axis into the batch axis so the backbone receives
        # ordinary rank-2 inputs.
        input_ids = rearrange(input_ids, 'b e l -> (b e) l')
        attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')

        # Only `last_hidden_state` is consumed below, so the per-layer hidden
        # states are not requested; asking for them would just keep one extra
        # tensor per layer alive without affecting the result.
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # One pooled vector per sequence, then restore the episode axis.
        # In the two patterns below `l` labels the feature axis, not length.
        comment_embeddings = self.mean_pooling(
            outputs['last_hidden_state'], attention_mask
        )
        comment_embeddings = rearrange(
            comment_embeddings, '(b e) l -> b e l', b=B, e=E
        )

        # Let the pooled vectors of an episode attend to each other, then
        # collapse the episode axis with an element-wise maximum.
        episode_embeddings = self.attn_fn(
            comment_embeddings, comment_embeddings, comment_embeddings
        )
        episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')

        episode_embeddings = self.linear(episode_embeddings)

        return episode_embeddings

    def forward(self, input_ids, attention_mask):
        """Calculates a fixed-length feature vector for a batch of episode samples.

        Thin wrapper around ``get_episode_embeddings``; takes the same
        arguments and returns its result unchanged.
        """
        output = self.get_episode_embeddings(input_ids, attention_mask)

        return output