Changes to be committed: modify the model so that the LSTM's initial memory (hidden and cell states) can be generated from x
32bf138

import torch.nn.functional as F
from pysdtw import SoftDTW
from torch import nn
import torch


class Vector2MIDI(nn.Module):
    def __init__(self, hidden_dim, X_dim=25, dropout=0.3):
        super().__init__()
        self.vocab_sizes = [101, 17, 73, 17, 17, 59, 17]  # vocab size per dimension, set from the actual data
        self.init_hidden = nn.Linear(X_dim, hidden_dim)
        self.init_cell = nn.Linear(X_dim, hidden_dim)
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, hidden_dim, padding_idx=16) for vocab_size in self.vocab_sizes
        ])
        # LSTM with dropout to reduce overfitting
        self.lstm = nn.LSTM(hidden_dim * len(self.vocab_sizes), hidden_dim, num_layers=2, batch_first=True, dropout=dropout)
        self.output_heads = nn.ModuleList([  # independent output head for each dimension
            nn.Linear(hidden_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])
        self.start_token_heads = nn.ModuleList([  # one head per dimension for generating the first token
            nn.Linear(X_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])

    def forward_hnc(self, x):
        """Build the initial hidden and cell states from the style vector."""
        h0 = torch.tanh(self.init_hidden(x))  # tanh activation (hyperbolic tangent) keeps the initial state in the LSTM's range
        c0 = torch.tanh(self.init_cell(x))
        h0 = h0.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)  # (num_layers, B, H)
        c0 = c0.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
        return h0, c0

    def forward_extend(self, y: torch.Tensor, h=None, c=None):
        """
        y: (B, T, 7) - 7-dimensional integer tokens (including EOS and padding)
        """
        emb_list = []
        for idx, emb_f in enumerate(self.embeddings):
            emb_list.append(emb_f(y[:, :, idx]))  # (B, T, hidden_dim)
        emb = torch.cat(emb_list, dim=-1)  # (B, T, 7 * hidden_dim)
        if h is not None and c is not None:
            out, (h, c) = self.lstm(emb, (h, c))
        else:
            out, (h, c) = self.lstm(emb)
        output = [head(out) for head in self.output_heads]  # list of (B, T, V_i)
        return output, (h, c)

    def forward_first(self, x: torch.Tensor):
        """x: 25-dimensional style vector"""
        logits_list = []
        for head in self.start_token_heads:
            logits = head(x)  # (B, vocab_size_i)
            logits_list.append(logits)
        return logits_list  # list of 7 tensors, each (B, vocab_size_i)

    def calc_loss(self, style_vec: torch.Tensor, seq: torch.Tensor):
        """
        style_vec: (B, 25)
        seq: (B, T, 7)
        """
        is_cuda = style_vec.device.type == "cuda"  # whether Soft-DTW should run on CUDA
        # Start-token loss (cross-entropy)
        logits_list = self.forward_first(style_vec)  # list of 7 tensors
        target_first = seq[:, 0, :]  # (B, 7), ground-truth class indices
        first_loss = 0
        for i in range(7):
            logits_i = logits_list[i]  # (B, vocab_size_i)
            target_i = target_first[:, i].long()  # (B,)
            first_loss += F.cross_entropy(logits_i, target_i)
        first_loss /= 7
        # Predict the initial hidden and cell states from the style vector
        (h, c) = self.forward_hnc(style_vec)
        # Sequence-extension loss (cross-entropy)
        pred_logits, _ = self.forward_extend(seq[:, :-1, :], h, c)
        target_seq = seq[:, 1:, :]
        extend_loss = 0
        pred_tokens = []
        for i in range(7):
            extend_loss += F.cross_entropy(
                pred_logits[i].reshape(-1, pred_logits[i].size(-1)).float(),
                target_seq[:, :, i].reshape(-1),
                ignore_index=16
            )
            # Extract predicted tokens with argmax (for Soft-DTW)
            pred_tokens.append(pred_logits[i].argmax(-1, keepdim=True))  # (B, T-1, 1)
        pred_tokens = torch.cat(pred_tokens, dim=-1)  # list -> Tensor, (B, T-1, 7)
        extend_loss /= 7
        # Soft-DTW loss
        soft_dtw = SoftDTW(use_cuda=is_cuda)
        min_len = min(pred_tokens.shape[1], target_seq.shape[1])
        sdtw_loss = soft_dtw(pred_tokens[:, :min_len, :].float(), target_seq[:, :min_len, :].float()).mean()
        sdtw_loss = torch.nan_to_num(torch.log1p(sdtw_loss))  # Soft-DTW can be negative; log1p yields NaN below -1, so replace NaNs
        return first_loss + extend_loss + 0.3 * sdtw_loss

    def _top_k_sampling(self, logits, top_k=5, temperature=1.0):
        """Temperature-scaled top-k sampling: keep the k largest logits and sample from their softmax."""
        logits = logits / temperature
        topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
        probs = F.softmax(topk_vals, dim=-1)
        idx = torch.multinomial(probs, 1).squeeze(-1)
        return topk_idx.gather(-1, idx.unsqueeze(-1)).squeeze(-1)

    def generate(self, x: torch.Tensor, max_len=128, top_k=5):  # TODO: replace the start token with an ordinary token and keep feeding the next token back in autoregressively
        """x: 25-dimensional style vector"""
        self.eval()
        batch_size = x.size(0)
        h, c = None, None
        start_tokens = torch.zeros(batch_size, 1, 7, dtype=torch.int64, device=x.device)
        for i, head in enumerate(self.start_token_heads):
            logits = head(x)  # (B, vocab_size_i)
            # Sample the first token from the style vector
            if i in [2, 5, 7]:  # duration dimensions: sample more conservatively
                probs = F.softmax(logits / 0.5, dim=-1)  # low temperature
            else:  # pitch, velocity, etc.: allow more diversity
                probs = F.softmax(logits / 1.2, dim=-1)  # slightly higher temperature
            token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            start_tokens[:, :, i] = token
        generated = [start_tokens.squeeze(0).squeeze(0).tolist()]
        for _ in range(max_len - 1):
            if h is None and c is None:
                h, c = self.forward_hnc(x)  # build the initial state from the style vector, consistent with calc_loss
            logits, (h, c) = self.forward_extend(start_tokens, h, c)  # logits: list of (B, T, V_i)
            last_logits = [log[:, -1, :] for log in logits]  # last step only
            sampled = []
            for i, logit in enumerate(last_logits):
                if i in [2, 5, 7]:  # duration dimensions: sample more conservatively
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=0.5)  # low temperature
                else:  # pitch, velocity, etc.: allow more diversity
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=1.2)  # slightly higher temperature
                sampled.append(token.item())
            if sampled == [100, 15, 72, 14, 15, 58, 15]:  # EOS token
                break
            else:
                generated.append(sampled)
                start_tokens = torch.tensor([[sampled]], device=x.device)  # (1, 1, 7)
        return torch.tensor(generated, device=x.device, dtype=torch.long)  # (seq_len, 7), seq_len <= max_len
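
For reference, here is a minimal usage sketch of the module above. The hidden size, batch size, sequence length, optimizer, and learning rate are illustrative assumptions rather than values taken from this commit; only the Vector2MIDI API (calc_loss and generate) comes from the code itself.

```python
# Minimal sketch: shapes and hyperparameters below are assumptions for illustration only.
import torch
from torch import optim

model = Vector2MIDI(hidden_dim=256)                  # hidden_dim=256 is an assumed value
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # optimizer and lr are assumptions

# Dummy batch: 8 style vectors (B, 25) and token sequences (B, T, 7).
# Values 0..16 are valid for every dimension (smallest vocab is 17); 16 is the padding index.
style_vec = torch.randn(8, 25)
seq = torch.randint(0, 17, (8, 64, 7))

model.train()
optimizer.zero_grad()
loss = model.calc_loss(style_vec, seq)  # start-token CE + extension CE + 0.3 * Soft-DTW
loss.backward()
optimizer.step()

# generate() calls .item() and squeezes the batch dimension, so it expects a single style vector.
with torch.no_grad():
    tokens = model.generate(style_vec[:1], max_len=128, top_k=5)  # (seq_len, 7)
```

Because generation breaks out of the loop as soon as the EOS tuple is sampled, the returned sequence can be shorter than max_len.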