miittnnss commited on
Commit
b0d0138
1 Parent(s): 459c365

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +16 -11
pipeline.py CHANGED
@@ -25,14 +25,19 @@ class PreTrainedPipeline():
25
  self.model = model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
26
 
27
  def __call__(self, inputs: str):
28
- """
29
- Args:
30
- inputs (:obj:`str`):
31
- a string containing some text
32
- Return:
33
- A :obj:`PIL.Image` with the raw image representation as PIL.
34
- """
35
- # IMPLEMENT_THIS
36
- raise NotImplementedError(
37
- "Please implement PreTrainedPipeline __call__ function"
38
- )
 
 
 
 
 
 
25
  self.model = model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
26
 
27
  def __call__(self, inputs: str):
28
+ seed_numerical_data = [char_to_index[char] for char in inputs]
29
+ with torch.no_grad():
30
+ input_sequence = torch.LongTensor([seed_numerical_data]).to(device)
31
+ hidden = model.init_hidden(1)
32
+
33
+ generated_text = inputs # Initialize generated text with seed text
34
+ temperature = 0.7 # Temperature for temperature sampling
35
+
36
+ for _ in range(500):
37
+ output, hidden = model(input_sequence, hidden)
38
+ probabilities = nn.functional.softmax(output[-1, 0] / temperature, dim=0).cpu().numpy()
39
+ predicted_index = random.choices(range(output_size), weights=probabilities, k=1)[0]
40
+ generated_text += index_to_char[predicted_index] # Append the generated character to the text
41
+ input_sequence = torch.LongTensor([[predicted_index]]).to(device)
42
+
43
+ return output