from typing import Tuple

import torch
import torch.nn as nn

HIDDEN_SIZE = 64      # LSTM hidden state size
VOCAB_SIZE = 196906   # vocabulary size for the embedding layer
EMBEDDING_DIM = 64    # embedding dimension
SEQ_LEN = 100         # input sequence length (tokens)
BATCH_SIZE = 16       # batch size used in the shape checks below
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over the LSTM outputs."""

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.W_q = nn.Linear(hidden_size, hidden_size)  # projects the query (final hidden state)
        self.W_k = nn.Linear(hidden_size, hidden_size)  # projects the keys (per-timestep outputs)
        self.W_v = nn.Linear(hidden_size, 1)            # scores each timestep
        self.tanh = nn.Tanh()

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        final_hidden: torch.Tensor,  # BATCH_SIZE x HIDDEN_SIZE
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        keys = self.W_k(lstm_outputs)                    # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        query = self.W_q(final_hidden)                   # BATCH_SIZE x HIDDEN_SIZE
        scores = self.tanh(query.unsqueeze(1) + keys)    # broadcast the query over the time axis
        scores = self.W_v(scores).squeeze(-1)            # BATCH_SIZE x SEQ_LEN
        att_weights = torch.softmax(scores, dim=-1)      # attention distribution over timesteps
        # Weighted sum of the LSTM outputs (the values), not of the projected keys.
        context = torch.bmm(att_weights.unsqueeze(1), lstm_outputs).squeeze(1)  # BATCH_SIZE x HIDDEN_SIZE
        return context, att_weights
# Quick shape check: the attention weights have one entry per timestep.
BahdanauAttention()(
    torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), torch.randn(BATCH_SIZE, HIDDEN_SIZE)
)[1].shape  # torch.Size([16, 100])
class LSTMConcatAttentionEmbed(nn.Module):
    """Embedding -> LSTM -> Bahdanau attention -> MLP classifier head."""

    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        # self.embedding = embedding_layer  # alternatively, a pretrained embedding layer
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = BahdanauAttention(HIDDEN_SIZE)
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        embeddings = self.embedding(x)             # BATCH_SIZE x SEQ_LEN x EMBEDDING_DIM
        outputs, (h_n, _) = self.lstm(embeddings)  # outputs: BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
        out = self.clf(att_hidden)                 # BATCH_SIZE x 1 (single logit)
        return out, att_weights
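
# A minimal end-to-end smoke test (a sketch not present in the original listing):
# run the full classifier on a random batch of token ids and check the output shapes.
# The model returns one logit per example plus the attention weights over the sequence.
model = LSTMConcatAttentionEmbed()
dummy_tokens = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
logits, weights = model(dummy_tokens)
print(logits.shape, weights.shape)  # torch.Size([16, 1]) torch.Size([16, 100])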