import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch.utils.checkpoint import checkpoint
from transformers import AutoModel, PreTrainedModel

from .config import LUARConfig


def exists(val):
    return val is not None

def summarize_qkv_chunk(q, k, v, mask):
    """Dot-product attention for a chunk of queries, keys, and values."""
    weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(mask):
        weight += mask

    # Subtract the per-row max before exponentiating for numerical stability;
    # the max is returned so chunks can later be renormalized against each other.
    weight_max = weight.amax(dim=-1, keepdim=True).detach()
    weight = weight - weight_max

    exp_weight = weight.exp()
    weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim=-1), weighted_value, rearrange(weight_max, '... 1 -> ...')

# Recompute chunk summaries during the backward pass instead of storing activations.
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

def memory_efficient_attention(
    q, k, v,
    mask=None,
    q_bucket_size=512,
    k_bucket_size=1024,
    eps=1e-8
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # Gradient checkpointing is only worthwhile when a backward pass is needed.
    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # Bucket the sequence dimension so only one (q_bucket_size x k_bucket_size)
    # block of attention weights is materialized at a time.
    q_chunks = q.split(q_bucket_size, dim=-2)
    k_chunks = k.split(k_bucket_size, dim=-2)
    v_chunks = v.split(k_bucket_size, dim=-2)
    mask_chunks = mask.split(k_bucket_size, dim=-1) if exists(mask) else ((None,) * len(k_chunks))

    out = []
    for q_chunk in q_chunks:
        exp_weights = []
        weighted_values = []
        weight_maxes = []

        for k_chunk, v_chunk, mask_chunk in zip(k_chunks, v_chunks, mask_chunks):
            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
            )

            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)
            weight_maxes.append(weight_max_chunk)

        exp_weights = torch.stack(exp_weights, dim=-1)
        weighted_values = torch.stack(weighted_values, dim=-1)
        weight_maxes = torch.stack(weight_maxes, dim=-1)

        # Renormalize every key chunk's partial softmax against the global max
        # so the partial sums combine into the exact softmax.
        global_max = weight_maxes.amax(dim=-1, keepdim=True)
        renorm_factor = (weight_maxes - global_max).exp().detach()

        exp_weights = exp_weights * renorm_factor
        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

        all_values = weighted_values.sum(dim=-1)
        all_weights = exp_weights.sum(dim=-1)

        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    return torch.cat(out, dim=-2)
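
# Minimal sanity sketch (the helper below is illustrative, not part of the
# module API): on small inputs the chunked computation should match naive
# softmax attention up to the eps term in the denominator.
def _check_memory_efficient_attention():
    q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))
    scores = torch.einsum('b h i d, b h j d -> b h i j', q, k) / q.shape[-1] ** 0.5
    naive = torch.softmax(scores, dim=-1) @ v
    chunked = memory_efficient_attention(q, k, v, q_bucket_size=4, k_bucket_size=4)
    assert torch.allclose(naive, chunked, atol=1e-5)
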
class SelfAttention(nn.Module):
    """Implements dot-product self-attention as used in "Attention Is All You Need"."""
    def __init__(
        self,
        memory_efficient_attention=False,
        q_bucket_size=512,
        k_bucket_size=1024,
    ):
        super().__init__()
        self.use_memory_efficient_attention = memory_efficient_attention
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(self, q, k, v):
        if self.use_memory_efficient_attention:
            # Split the hidden dimension into 12 heads, matching the number of
            # attention heads in the DistilRoBERTa backbone below.
            q, k, v = map(
                lambda t: rearrange(t, 'b n (h d) -> b h n d', h=12),
                (q, k, v)
            )

            out = memory_efficient_attention(
                q, k, v,
                q_bucket_size=self.q_bucket_size,
                k_bucket_size=self.k_bucket_size
            )
            out = rearrange(out, 'b h n d -> b n (h d)')
            return out
        else:
            d_k = q.size(-1)
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
            p_attn = F.softmax(scores, dim=-1)
            return torch.matmul(p_attn, v)
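
# Usage sketch (the helper name is illustrative, not part of the module API):
# with the default settings this is plain softmax attention, and LUAR below
# calls it with identical query, key, and value tensors.
def _demo_self_attention():
    attn = SelfAttention()
    x = torch.randn(2, 16, 768)  # (batch, documents per episode, hidden size)
    out = attn(x, x, x)
    assert out.shape == x.shape
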
class LUAR(PreTrainedModel):
    """Defines the LUAR model."""
    config_class = LUARConfig

    def __init__(self, config):
        super().__init__(config)
        self.create_transformer()
        self.attn_fn = SelfAttention(
            config.use_memory_efficient_attention,
            config.q_bucket_size,
            config.k_bucket_size,
        )
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

    def create_transformer(self):
        """Creates the Transformer backbone."""
        self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
        self.hidden_size = self.transformer.config.hidden_size
        self.num_attention_heads = self.transformer.config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean pooling as described in the SBERT paper."""
        # Expand the mask over the hidden dimension, zero out padding tokens,
        # and average over the valid positions only (clamped to avoid division
        # by zero for fully padded rows).
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        return sum_embeddings / sum_mask
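
    # Worked example (illustrative; `model` stands for any LUAR instance):
    # with attention_mask [[1, 1, 0]] the last token is padding, so the pooled
    # vector equals the average of the first two token embeddings:
    #
    #   tokens = torch.randn(1, 3, model.hidden_size)
    #   mask = torch.tensor([[1, 1, 0]])
    #   pooled = model.mean_pooling(tokens, mask)  # == tokens[:, :2].mean(dim=1)
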
    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
        """Computes the author embedding."""
        B, E, _ = attention_mask.shape

        # Flatten the episode dimension so each document is encoded as an
        # ordinary batch element.
        input_ids = rearrange(input_ids, 'b e l -> (b e) l')
        attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')

        if document_batch_size > 0:
            # Encode the documents in mini-batches to bound peak memory.
            outputs = {"last_hidden_state": [], "attentions": []}
            for i in range(0, len(input_ids), document_batch_size):
                out = self.transformer(
                    input_ids=input_ids[i:i + document_batch_size],
                    attention_mask=attention_mask[i:i + document_batch_size],
                    return_dict=True,
                    output_hidden_states=False,
                    output_attentions=output_attentions,
                )
                outputs["last_hidden_state"].append(out["last_hidden_state"])
                if output_attentions:
                    outputs["attentions"].append(out["attentions"])
            outputs["last_hidden_state"] = torch.cat(outputs["last_hidden_state"], dim=0)
            if output_attentions:
                # Re-assemble the per-layer attention tuples across mini-batches.
                outputs["attentions"] = tuple(
                    torch.cat([x[i] for x in outputs["attentions"]], dim=0)
                    for i in range(len(outputs["attentions"][0]))
                )
        else:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=False,
                output_attentions=output_attentions,
            )

        # Pool each document into one vector, then restore the episode axis.
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b e) d -> b e d', b=B, e=E)

        # Attend over the documents within an episode and max-pool the result
        # into a single embedding per author.
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e d -> b d', 'max')

        episode_embeddings = self.linear(episode_embeddings)

        if output_attentions:
            return episode_embeddings, outputs["attentions"]

        return episode_embeddings
    def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
        """Calculates a fixed-length feature vector for a batch of episode samples."""
        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions, document_batch_size)

        return output
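
# End-to-end usage sketch (illustrative; assumes a default LUARConfig and the
# tokenizer that ships with the backbone checkpoint):
#
#   from transformers import AutoTokenizer
#
#   model = LUAR(LUARConfig())
#   tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
#   batch = tokenizer(["first post", "second post"], padding="max_length",
#                     truncation=True, max_length=32, return_tensors="pt")
#   # One author (b=1) with an episode of two documents (e=2).
#   input_ids = rearrange(batch["input_ids"], '(b e) l -> b e l', b=1, e=2)
#   attention_mask = rearrange(batch["attention_mask"], '(b e) l -> b e l', b=1, e=2)
#   embedding = model(input_ids, attention_mask)  # shape: (1, model.config.embedding_size)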