class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
    """Token-level LSTM language model: embedding -> stacked LSTM -> linear head.

    Args:
        input_size: Number of rows in the embedding table (size of the input
            token vocabulary).
        hidden_size: Embedding dimension; also the LSTM hidden/cell size.
        output_size: Number of logits produced per position.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied by ``nn.LSTM`` to the outputs of
            every LSTM layer except the last.
    """

    def __init__(self, input_size=45, hidden_size=512, output_size=45, num_layers=2, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        # batch_first=True: the LSTM consumes (batch, seq, hidden_size) tensors.
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run one forward pass.

        Args:
            x: Integer token ids. Because the LSTM is ``batch_first``, this is
                presumably shaped (batch, seq). It is cast to ``torch.long``
                because ``nn.Embedding`` requires integer indices.
            hidden: ``(h_0, c_0)`` state tuple, e.g. from ``init_hidden``. If
                None, ``nn.LSTM`` starts from a zero state.

        Returns:
            ``(logits, (h_n, c_n))``, where ``logits`` has ``output_size`` as
            its last dimension.
        """
        x = x.to(torch.long)
        x = self.embedding(x)
        x, hidden = self.lstm(x, hidden)
        x = self.fc(x)
        return x, hidden

    def init_hidden(self, batch_size, device=None):
        """Return a zero ``(h_0, c_0)`` state for ``batch_size`` sequences.

        Each tensor has shape (num_layers, batch_size, hidden_size).

        Args:
            batch_size: Number of sequences in the batch.
            device: Device for the state tensors. Defaults to the device of the
                model's parameters, so the state always matches the model
                without relying on a global ``device`` variable.

        Returns:
            Tuple ``(h_0, c_0)`` of zero tensors with the parameters' dtype.
        """
        weight = next(self.parameters())
        if device is None:
            device = weight.device
        shape = (self.num_layers, batch_size, self.hidden_size)
        # The LSTM requires the state to share the input's device and dtype.
        h0 = torch.zeros(shape, dtype=weight.dtype, device=device)
        c0 = torch.zeros(shape, dtype=weight.dtype, device=device)
        return h0, c0


class PreTrainedPipeline:
    """Inference wrapper that loads ``LSTMTextGenerator`` weights from the Hub."""

    def __init__(self, path=""):
        # ``path`` is currently unused: weights are always loaded from the fixed
        # Hub repository below.
        self.model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
        # Inference only: switch to eval mode so LSTM inter-layer dropout is off.
        self.model.eval()

    def __call__(self, inputs: str):
        """Generate text from a prompt (not implemented yet).

        Args:
            inputs: A string containing some text.

        Returns:
            Not yet defined. The previous template docstring promised a
            ``PIL.Image``, which does not fit a text-generation model.
            TODO: decide the output format (e.g. the generated string).

        Raises:
            NotImplementedError: Always, until generation is implemented.
        """
        # IMPLEMENT_THIS
        raise NotImplementedError(
            "Please implement PreTrainedPipeline __call__ function"
        )