# LUAR-MUD / model.py
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch.utils.checkpoint import checkpoint
from transformers import AutoModel, PreTrainedModel

from .config import LUARConfig

# Adapted LucidRains impl. of Memory Efficient Attention
# https://github.com/lucidrains/memory-efficient-attention-pytorch
def exists(val):
    return val is not None

def summarize_qkv_chunk(
    q, k, v,
    mask
):
    """Dot-Product Attention for a chunk of queries, keys, and values.
    """
    weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(mask):
        # HuggingFace masks have to be added:
        weight += mask

    # subtract the per-row max before exponentiating for numerical stability;
    # the max is returned so chunks can later be renormalized against the global max
    weight_max = weight.amax(dim = -1, keepdim = True).detach()
    weight = weight - weight_max

    exp_weight = weight.exp()
    weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')

checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

def memory_efficient_attention(
    q, k, v,
    mask = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # checkpoint the chunk summaries only when a backward pass will be needed
    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # chunk all the inputs
    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    # loop through all chunks and accumulate
    out = []
    for q_index, q_chunk in enumerate(q_chunks):
        exp_weights = []
        weighted_values = []
        weight_maxes = []

        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
            )

            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)
            weight_maxes.append(weight_max_chunk)

        exp_weights = torch.stack(exp_weights, dim = -1)
        weighted_values = torch.stack(weighted_values, dim = -1)
        weight_maxes = torch.stack(weight_maxes, dim = -1)

        # renormalize every key chunk against the global max before combining them
        global_max = weight_maxes.amax(dim = -1, keepdim = True)
        renorm_factor = (weight_maxes - global_max).exp().detach()

        exp_weights = exp_weights * renorm_factor
        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

        all_values = weighted_values.sum(dim = -1)
        all_weights = exp_weights.sum(dim = -1)

        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    return torch.cat(out, dim=-2)
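
# For inputs small enough to materialize the full attention matrix, the chunked
# computation above should closely match ordinary softmax attention, since the
# per-chunk max subtraction and renormalization only reorganize the same sums.
# Illustrative sketch (shapes and tolerance chosen arbitrarily, not from the original code):
#
#   q = k = v = torch.randn(1, 12, 64, 64)
#   ref = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5, dim=-1) @ v
#   out = memory_efficient_attention(q, k, v)
#   torch.allclose(out, ref, atol=1e-5)  # expected to hold up to the eps term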

class SelfAttention(nn.Module):
    """Implements Dot-Product Self-Attention as used in "Attention Is All You Need".
    """
    def __init__(
        self,
        memory_efficient_attention=False,
        q_bucket_size=512,
        k_bucket_size=1024,
    ):
        super(SelfAttention, self).__init__()
        self.use_memory_efficient_attention = memory_efficient_attention
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(self, k, q, v):
        if self.use_memory_efficient_attention:
            # split the hidden dimension into 12 heads (the DistilRoBERTa backbone's head count)
            q, k, v = map(
                lambda t: rearrange(t, 'b n (h d) -> b h n d', h = 12),
                (q, k, v)
            )

            out = memory_efficient_attention(
                q, k, v,
                q_bucket_size=self.q_bucket_size,
                k_bucket_size=self.k_bucket_size
            )
            out = rearrange(out, 'b h n d -> b n (h d)')
            return out
        else:
            d_k = q.size(-1)
            scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
            p_attn = F.softmax(scores, dim=-1)
            return torch.matmul(p_attn, v)
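
# LUAR calls SelfAttention with q = k = v set to the per-document embeddings of shape
# (batch, n_documents, hidden_size), so attention mixes information across the documents
# of an episode rather than across tokens; the memory-efficient branch additionally
# splits the hidden size into heads before attending.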

class LUAR(PreTrainedModel):
    """Defines the LUAR model.
    """
    config_class = LUARConfig

    def __init__(self, config):
        super().__init__(config)
        self.create_transformer()
        self.attn_fn = SelfAttention(
            config.use_memory_efficient_attention,
            config.q_bucket_size,
            config.k_bucket_size,
        )
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

    def create_transformer(self):
        """Creates the Transformer backbone.
        """
        self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1")
        self.hidden_size = self.transformer.config.hidden_size
        self.num_attention_heads = self.transformer.config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean Pooling as described in the SBERT paper.
        """
        # average the token embeddings over non-padding positions:
        # pooled[b, d] = sum_l mask[b, l] * h[b, l, d] / max(sum_l mask[b, l], 1e-9)
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        return sum_embeddings / sum_mask

    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
        """Computes the Author Embedding.
        """
        B, E, _ = attention_mask.shape

        input_ids = rearrange(input_ids, 'b e l -> (b e) l')
        attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')

        if document_batch_size > 0:
            # embed the documents in smaller batches to bound peak memory
            outputs = {"last_hidden_state": [], "attentions": []}
            for i in range(0, len(input_ids), document_batch_size):
                out = self.transformer(
                    input_ids=input_ids[i:i+document_batch_size],
                    attention_mask=attention_mask[i:i+document_batch_size],
                    return_dict=True,
                    output_hidden_states=False,
                    output_attentions=output_attentions,
                )
                outputs["last_hidden_state"].append(out["last_hidden_state"])
                if output_attentions:
                    outputs["attentions"].append(out["attentions"])

            outputs["last_hidden_state"] = torch.cat(outputs["last_hidden_state"], dim=0)
            if output_attentions:
                # re-assemble the per-layer attention tuples across the mini-batches
                outputs["attentions"] = tuple(
                    torch.cat([x[i] for x in outputs["attentions"]], dim=0)
                    for i in range(len(outputs["attentions"][0]))
                )
        else:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=False,
                output_attentions=output_attentions,
            )

        # at this point, we're embedding individual "comments"
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)

        # aggregate individual comment embeddings into episode embeddings
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')

        episode_embeddings = self.linear(episode_embeddings)

        if output_attentions:
            return episode_embeddings, outputs["attentions"]

        return episode_embeddings

    def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
        """Calculates a fixed-length feature vector for a batch of episode samples.
        """
        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions, document_batch_size)

        return output
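

if __name__ == "__main__":
    # Minimal usage sketch. It assumes LUARConfig exposes embedding_size,
    # use_memory_efficient_attention, q_bucket_size, and k_bucket_size (the attributes
    # read in LUAR.__init__ above); the value 512 for embedding_size is illustrative.
    # Because of the relative import of LUARConfig, run this as part of its package
    # (python -m <package>.model) or adapt the import when copying the sketch elsewhere.
    config = LUARConfig(
        embedding_size=512,
        use_memory_efficient_attention=False,
        q_bucket_size=512,
        k_bucket_size=1024,
    )
    model = LUAR(config).eval()

    # Inputs are (batch, episode_length, token_length): each "episode" is a set of
    # documents by one author; random token ids are used here purely to exercise shapes.
    batch_size, episode_length, token_length = 2, 4, 32
    vocab_size = model.transformer.config.vocab_size
    input_ids = torch.randint(0, vocab_size, (batch_size, episode_length, token_length))
    attention_mask = torch.ones(batch_size, episode_length, token_length, dtype=torch.long)

    with torch.no_grad():
        embeddings = model(input_ids, attention_mask)
    print(embeddings.shape)  # expected: (batch_size, config.embedding_size)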