Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from . import encode, decode | |
| from typing import Union, List | |
class Shakkelha(nn.Module):
    """Stacked bidirectional-LSTM tagger producing per-position class probabilities.

    Layer stack, as built in ``__init__``: embedding (25-d) -> BiLSTM
    (2 x 256) -> dropout -> BiLSTM (2 x 256) -> dropout -> two ReLU dense
    layers (512) -> dense projection to ``dim_output`` -> softmax.

    The text-level helpers (``predict`` and the private ``_predict_*``
    methods) convert between strings and ids through the package-level
    ``encode`` / ``decode`` functions.
    NOTE(review): the exact contracts of ``encode`` / ``decode`` are not
    visible from this class; confirm them where they are defined.
    """

    def __init__(self,
                 dim_input: int = 91,
                 dim_output: int = 19,
                 sd_path: Union[str, None] = None):
        """Build the network and optionally load weights.

        Args:
            dim_input: Number of rows in the embedding table (input vocabulary).
            dim_output: Number of output classes of the final dense layer.
            sd_path: Optional path to a saved ``state_dict``. When ``None``
                the model keeps its random initialisation.
        """
        super().__init__()
        self.emb0 = nn.Embedding(dim_input, 25)
        self.lstm0 = nn.LSTM(25, 256, batch_first=True, bidirectional=True)
        # Input size is 512 because the previous LSTM is bidirectional (2 * 256).
        self.lstm1 = nn.LSTM(512, 256, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(p=0.5)
        self.dense0 = nn.Linear(512, 512)
        self.dense1 = nn.Linear(512, 512)
        self.dense2 = nn.Linear(512, dim_output)

        # The model starts in eval mode so dropout is inactive unless the
        # caller explicitly switches to ``train()``.
        self.eval()

        if sd_path is not None:
            # map_location="cpu" lets a checkpoint saved from a CUDA device be
            # loaded on a CPU-only host; load_state_dict then copies the
            # tensors onto whatever device the parameters live on.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(sd_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a tensor of integer ids to class probabilities.

        Args:
            x: Integer id tensor. The LSTMs are ``batch_first``, so a batched
                input is laid out as ``(batch, seq)``.

        Returns:
            Tensor with one extra trailing dimension of size ``dim_output``
            holding softmax probabilities (they sum to 1 along that axis).
        """
        x = self.emb0(x)
        x, _ = self.lstm0(x)
        x = self.dropout(x)
        x, _ = self.lstm1(x)
        x = self.dropout(x)
        x = F.relu(self.dense0(x))
        x = F.relu(self.dense1(x))
        return F.softmax(self.dense2(x), dim=-1)

    @torch.no_grad()
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``forward`` with gradient tracking disabled.

        Inference does not need an autograd graph; disabling it avoids the
        extra memory and returns a tensor with ``requires_grad=False``.
        Use ``forward`` directly when gradients are required.
        """
        return self.forward(x)

    def _predict_list(self, input_list: List[str], return_probs: bool = False):
        """Apply ``_predict_single`` to every string in ``input_list``.

        Returns:
            A list of decoded outputs, or a ``(outputs, probs)`` pair of
            parallel lists when ``return_probs`` is true.
        """
        if not return_probs:
            return [self._predict_single(text) for text in input_list]

        output_list = []
        probs_list = []
        for text in input_list:
            output_text, probs = self._predict_single(text, return_probs=True)
            output_list.append(output_text)
            probs_list.append(probs)
        return output_list, probs_list

    def _predict_single(self, input_text: str, return_probs: bool = False):
        """Encode one string, run the network, and decode the result.

        Returns:
            The value produced by ``decode``, or ``(decoded, probs)`` when
            ``return_probs`` is true, where ``probs`` is the CPU tensor of
            class probabilities with a leading batch dimension of 1.
        """
        ids = encode(input_text)
        # [None] adds the batch dimension; the tensor is moved to the device
        # the model's parameters currently live on.
        input_ids = torch.LongTensor(ids)[None].to(self.emb0.weight.device)
        probs = self.infer(input_ids).cpu()
        output = decode(probs, input_text)
        if return_probs:
            return output, probs
        return output

    def predict(self, input: Union[str, List[str]], return_probs: bool = False):
        """Predict for a single string or a list of strings.

        Args:
            input: One string, or a list of strings processed one by one.
                (The name shadows the builtin but is kept so existing
                keyword callers continue to work.)
            return_probs: When true, also return the class probabilities.

        Returns:
            For a string: the decoded output, or ``(output, probs)``.
            For a list: a list of outputs, or ``(outputs, probs_list)``.
        """
        if isinstance(input, str):
            return self._predict_single(input, return_probs=return_probs)
        return self._predict_list(input, return_probs=return_probs)