asigalov61 commited on
Commit
5b360d1
·
1 Parent(s): 3dd0e68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -23
app.py CHANGED
@@ -25,7 +25,7 @@ def generate(
25
  temperature = 0.9,
26
  verbose=False,
27
  return_prime=False,
28
- ):
29
 
30
  out = torch.LongTensor([start_tokens])
31
 
@@ -34,29 +34,23 @@ def generate(
34
  if verbose:
35
  print("Generating sequence of max length:", seq_len)
36
 
37
- max_len = seq_len
38
- cur_len = 0
39
-
40
- bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
41
- with bar:
42
- while cur_len < max_len:
43
-
44
- x = out[:, -max_seq_len:]
45
-
46
- torch_in = x.tolist()[0]
47
-
48
- logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
49
-
50
- filtered_logits = logits
51
-
52
- probs = F.softmax(filtered_logits / temperature, dim=-1)
53
-
54
- sample = torch.multinomial(probs, 1)
55
 
56
- out = torch.cat((out, sample), dim=-1)
57
-
58
- cur_len += 1
59
- bar.update(1)
 
 
 
 
 
 
 
 
 
60
 
61
  if return_prime:
62
  return out[:, :]
 
25
  temperature = 0.9,
26
  verbose=False,
27
  return_prime=False,
28
+ progress=gr.Progress()):
29
 
30
  out = torch.LongTensor([start_tokens])
31
 
 
34
  if verbose:
35
  print("Generating sequence of max length:", seq_len)
36
 
37
+ progress(0, desc="Starting...")
38
+
39
+ for i in progress.tqdm(range(seq_len)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ x = out[:, -max_seq_len:]
42
+
43
+ torch_in = x.tolist()[0]
44
+
45
+ logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
46
+
47
+ filtered_logits = logits
48
+
49
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
50
+
51
+ sample = torch.multinomial(probs, 1)
52
+
53
+ out = torch.cat((out, sample), dim=-1)
54
 
55
  if return_prime:
56
  return out[:, :]