Балаганский Никита Николаевич
committed on
Commit
•
9dc62b8
1
Parent(s):
e7e873a
fix fp16
Browse files
app.py
CHANGED
@@ -7,7 +7,6 @@ import torch
|
|
7 |
|
8 |
import transformers
|
9 |
import tokenizers
|
10 |
-
from torch import autocast
|
11 |
|
12 |
from sampling import CAIFSampler, TopKWithTemperatureSampler
|
13 |
from generator import Generator
|
@@ -57,6 +56,10 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
|
|
57 |
generator.set_caif_sampler(caif_sampler)
|
58 |
ordinary_sampler = TopKWithTemperatureSampler()
|
59 |
generator.set_ordinary_sampler(ordinary_sampler)
|
|
|
|
|
|
|
|
|
60 |
with autocast(fp16):
|
61 |
sequences, tokens = generator.sample_sequences(
|
62 |
num_samples=1,
|
|
|
7 |
|
8 |
import transformers
|
9 |
import tokenizers
|
|
|
10 |
|
11 |
from sampling import CAIFSampler, TopKWithTemperatureSampler
|
12 |
from generator import Generator
|
|
|
56 |
generator.set_caif_sampler(caif_sampler)
|
57 |
ordinary_sampler = TopKWithTemperatureSampler()
|
58 |
generator.set_ordinary_sampler(ordinary_sampler)
|
59 |
+
if device == "cpu":
|
60 |
+
autocast = torch.cpu.amp.autocast
|
61 |
+
else:
|
62 |
+
autocast = torch.cuda.amp.autocast
|
63 |
with autocast(fp16):
|
64 |
sequences, tokens = generator.sample_sequences(
|
65 |
num_samples=1,
|