import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModel
# from models.layers import MemoryEfficientAttention, SelfAttention

from .config import LUARConfig


class UARScene(
    nn.Module,
    PyTorchModelHubMixin,
):
    """Defines the UAR Scene model: an SBERT (all-distilroberta-v1) backbone whose
    comment embeddings are aggregated into a single episode (author) embedding.
    """
    config_class = LUARConfig

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.create_transformer()
        # Projects the backbone hidden size down to the final embedding size.
        self.linear = nn.Linear(self.hidden_size, self.config.embedding_size)

    def attn_fn(self, k, q, v):
        """Scaled dot-product attention used to aggregate comment embeddings."""
        d_k = q.size(-1)
        scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)

    def create_transformer(self):
        """Creates the Transformer model."""
        self.transformer = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
        self.hidden_size = self.transformer.config.hidden_size
        self.num_attention_heads = self.transformer.config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean pooling as described in the SBERT paper."""
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        return sum_embeddings / sum_mask
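
    # Note on the einops patterns in get_episode_embeddings below:
    # b = batch size (authors), n = samples per author (1 in this code path), and
    # e = episode length (number of comments). Each comment is mean-pooled into a
    # single vector, and the comment vectors are then aggregated into one episode
    # embedding via self-attention followed by max pooling.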

    def get_episode_embeddings(self, data):
        """Computes the author (episode) embedding."""
        # batch_size, num_samples_per_author, episode_length, max_token_length
        input_ids, attention_mask = data[0].unsqueeze(1), data[1].unsqueeze(1)
        B, N, E, _ = input_ids.shape

        # Flatten so that every comment is a separate row for the transformer.
        input_ids = rearrange(input_ids, 'b n e l -> (b n e) l')
        attention_mask = rearrange(attention_mask, 'b n e l -> (b n e) l')

        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=True,
        )

        # At this point, we're embedding individual "comments".
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b n e) l -> (b n) e l', b=B, n=N, e=E)

        # Aggregate the individual comment embeddings into episode embeddings.
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max')

        episode_embeddings = self.linear(episode_embeddings)

        return episode_embeddings, comment_embeddings

    def forward(self, input_ids, attention_mask):
        """Calculates a fixed-length feature vector for a batch of episode samples."""
        data = [input_ids, attention_mask]
        episode_embeddings, _ = self.get_episode_embeddings(data)
        return episode_embeddings

    def _model_forward(self, batch):
        """Passes a batch of data through the model.

        This is used in the lightning_trainer.py file.
        """
        data, _, _ = batch
        # forward() only returns the episode embeddings, so call
        # get_episode_embeddings() directly to also obtain the comment embeddings.
        episode_embeddings, comment_embeddings = self.get_episode_embeddings(data)
        # labels = torch.flatten(labels)
        return episode_embeddings, comment_embeddings
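

# Example usage (illustrative sketch, not part of the original module). The
# LUARConfig field shown (embedding_size) and the tensor shapes below are
# assumptions for demonstration; adapt them to the actual tokenization pipeline.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = LUARConfig(embedding_size=512)  # assumed config signature
    model = UARScene(config)
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")

    # Pretend episode: 4 comments from one author, padded to 32 tokens each.
    comments = ["first comment", "second comment", "third comment", "fourth comment"]
    tokens = tokenizer(comments, padding="max_length", truncation=True,
                       max_length=32, return_tensors="pt")

    # Add the batch (author) dimension expected by get_episode_embeddings: (B, E, L).
    input_ids = tokens["input_ids"].unsqueeze(0)
    attention_mask = tokens["attention_mask"].unsqueeze(0)

    with torch.no_grad():
        episode_embedding = model(input_ids, attention_mask)
    print(episode_embedding.shape)  # expected: (1, config.embedding_size)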