Spaces:
Sleeping
Sleeping
tidy
Browse files
app.py
CHANGED
@@ -41,7 +41,7 @@ def generate_step(out: object,
|
|
41 |
|
42 |
args:
|
43 |
- out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
|
44 |
-
- gen_idx (int): location for which to generate
|
45 |
- top_k (int): if >0, only sample from the top k most probable words
|
46 |
- temperature (float): sampling temperature
|
47 |
- typical_p (float): if >0 use typical sampling
|
@@ -53,13 +53,13 @@ def generate_step(out: object,
|
|
53 |
logits = out.logits[:, gen_idx]
|
54 |
warpers = LogitsProcessorList()
|
55 |
if temperature:
|
56 |
-
warpers
|
57 |
if top_k > 0:
|
58 |
-
warpers
|
59 |
if typical_p > 0:
|
60 |
if typical_p >= 1:
|
61 |
typical_p = 0.999
|
62 |
-
warpers
|
63 |
logits = warpers(None, logits)
|
64 |
|
65 |
if sample:
|
|
|
41 |
|
42 |
args:
|
43 |
- out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
|
44 |
+
- gen_idx (int): location for which to generate
|
45 |
- top_k (int): if >0, only sample from the top k most probable words
|
46 |
- temperature (float): sampling temperature
|
47 |
- typical_p (float): if >0 use typical sampling
|
|
|
53 |
logits = out.logits[:, gen_idx]
|
54 |
warpers = LogitsProcessorList()
|
55 |
if temperature:
|
56 |
+
warpers.append(TemperatureLogitsWarper(temperature))
|
57 |
if top_k > 0:
|
58 |
+
warpers.append(TopKLogitsWarper(top_k))
|
59 |
if typical_p > 0:
|
60 |
if typical_p >= 1:
|
61 |
typical_p = 0.999
|
62 |
+
warpers.append(TypicalLogitsWarper(typical_p))
|
63 |
logits = warpers(None, logits)
|
64 |
|
65 |
if sample:
|