File size: 3,825 Bytes
9afee78
 
 
 
 
 
 
 
 
 
 
 
 
5d0546f
6173606
 
9afee78
6173606
 
 
 
9afee78
 
5d0546f
 
9afee78
ef7e93f
 
 
9afee78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import math
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModel
from transformers import AutoModel, PreTrainedModel

# from models.layers import MemoryEfficientAttention, SelfAttention
from .config import LUARConfig


class UARScene(
    nn.Module,
    PyTorchModelHubMixin,
):
    """Episode encoder built on a pretrained sentence-transformer backbone.

    Each comment in an episode is embedded by the transformer and mean-pooled
    over its tokens. The comment embeddings are mixed with self-attention,
    max-pooled over the episode, and projected to ``config.embedding_size``.
    """

    config_class = LUARConfig

    def __init__(self, config):
        """Build the encoder.

        Args:
            config: model configuration. Only ``config.embedding_size`` (the
                output dimension of the final projection) is read here.
        """
        super().__init__()
        self.config = config
        # Sets self.transformer / self.hidden_size (and head sizes); must run
        # before the projection layer, which needs hidden_size.
        self.create_transformer()
        self.linear = nn.Linear(self.hidden_size, self.config.embedding_size)

    def attn_fn(self, k, q, v):
        """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

        Args:
            k: keys, ``(..., S_k, d)``.
            q: queries, ``(..., S_q, d)``.
            v: values, ``(..., S_k, d_v)``.

        Returns:
            Tensor of shape ``(..., S_q, d_v)``.
        """
        d_k = q.size(-1)
        # Each query row attends over the keys (softmax along the key axis).
        # For self-attention (k and q are the same tensor) the score matrix is
        # symmetric, so the argument roles do not matter in that case.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)

    def create_transformer(self):
        """Load the pretrained backbone and cache its dimensions.

        Sets ``self.transformer``, ``self.hidden_size``,
        ``self.num_attention_heads`` and ``self.dim_head``.
        """
        self.transformer = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
        self.hidden_size = self.transformer.config.hidden_size
        self.num_attention_heads = self.transformer.config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean-pool token embeddings over unmasked positions (as in SBERT).

        Args:
            token_embeddings: ``(b, l, d)`` with ``d == self.hidden_size``.
            attention_mask: ``(b, l)``. Values are used as per-token weights
                (presumably 1 for real tokens, 0 for padding).

        Returns:
            ``(b, d)`` mask-weighted mean of the token embeddings.
        """
        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
        sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
        # Clamp so an all-padding row divides by a tiny value instead of zero.
        sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
        return sum_embeddings / sum_mask

    def get_episode_embeddings(self, data):
        """Compute episode embeddings together with per-comment embeddings.

        Args:
            data: pair ``(input_ids, attention_mask)``, each of shape
                ``(batch, episode_length, seq_len)``. A singleton
                samples-per-author axis is inserted internally.

        Returns:
            ``(episode_embeddings, comment_embeddings)`` with shapes
            ``(batch, embedding_size)`` and
            ``(batch, episode_length, hidden_size)``.
        """
        # (b, e, l) -> (b, 1, e, l): one sample per author.
        input_ids, attention_mask = data[0].unsqueeze(1), data[1].unsqueeze(1)
        B, N, E, _ = input_ids.shape

        # Flatten so the backbone sees one comment per row.
        input_ids = rearrange(input_ids, 'b n e l -> (b n e) l')
        attention_mask = rearrange(attention_mask, 'b n e l -> (b n e) l')

        # Only last_hidden_state is consumed, so intermediate layer states are
        # not requested (saves memory).
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # One vector per comment, regrouped by episode.
        comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b n e) d -> (b n) e d', b=B, n=N, e=E)

        # Mix comments with self-attention, then max-pool over the episode.
        episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        episode_embeddings = reduce(episode_embeddings, 'b e d -> b d', 'max')

        episode_embeddings = self.linear(episode_embeddings)
        return episode_embeddings, comment_embeddings

    def forward(self, input_ids, attention_mask):
        """Return fixed-length ``(batch, embedding_size)`` episode embeddings.

        Args:
            input_ids: ``(batch, episode_length, seq_len)`` token ids.
            attention_mask: mask with the same shape as ``input_ids``.
        """
        episode_embeddings, _ = self.get_episode_embeddings([input_ids, attention_mask])
        return episode_embeddings

    def _model_forward(self, batch):
        """Pass one training batch through the model.

        Used by the lightning_trainer.py file.

        Args:
            batch: 3-tuple whose first element is presumably the
                ``(input_ids, attention_mask)`` pair expected by
                ``get_episode_embeddings`` (confirm against the dataloader);
                the other two entries are ignored here.

        Returns:
            ``(episode_embeddings, comment_embeddings)``.
        """
        data, _, _ = batch
        # forward() takes the two tensors as separate arguments and returns
        # only the episode embeddings; both outputs are needed here.
        return self.get_episode_embeddings(data)