miittnnss commited on
Commit
459c365
1 Parent(s): f9ef14f

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +3 -9
pipeline.py CHANGED
@@ -1,6 +1,6 @@
1
 
2
- class LSTMTextGenerator(nn.Module):
3
- def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.5):
4
  super(LSTMTextGenerator, self).__init__()
5
  self.embedding = nn.Embedding(input_size, hidden_size)
6
  self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
@@ -22,13 +22,7 @@ class LSTMTextGenerator(nn.Module):
22
 
23
  class PreTrainedPipeline():
24
  def __init__(self, path=""):
25
- # IMPLEMENT_THIS
26
- # Preload all the elements you are going to need at inference.
27
- # For instance your model, processors, tokenizer that might be needed.
28
- # This function is only called once, so do all the heavy processing I/O here"""
29
- raise NotImplementedError(
30
- "Please implement PreTrainedPipeline __init__ function"
31
- )
32
 
33
  def __call__(self, inputs: str):
34
  """
 
1
 
2
+ class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
3
+ def __init__(self, input_size=45, hidden_size=512, output_size=45, num_layers=2, dropout=0.5):
4
  super(LSTMTextGenerator, self).__init__()
5
  self.embedding = nn.Embedding(input_size, hidden_size)
6
  self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
 
22
 
23
  class PreTrainedPipeline():
24
  def __init__(self, path=""):
25
+ self.model = model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
 
 
 
 
 
 
26
 
27
  def __call__(self, inputs: str):
28
  """