Балаганский Никита Николаевич
commited on
Commit
•
e852933
1
Parent(s):
570bb74
everything wrapped in cache
Browse files- generator.py +3 -3
- sampling.py +3 -3
- utils.py +13 -0
generator.py
CHANGED
@@ -6,6 +6,8 @@ import streamlit as st
|
|
6 |
|
7 |
from plotly import graph_objects as go
|
8 |
|
|
|
|
|
9 |
|
10 |
class Generator:
|
11 |
def __init__(self, lm_model_name, device, entropy=None):
|
@@ -15,9 +17,7 @@ class Generator:
|
|
15 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
16 |
lm_model_name
|
17 |
)
|
18 |
-
self.lm =
|
19 |
-
lm_model_name
|
20 |
-
).to(device)
|
21 |
self.lm.eval()
|
22 |
|
23 |
self.lm.config.pad_token_id = self.lm.config.eos_token_id
|
|
|
6 |
|
7 |
from plotly import graph_objects as go
|
8 |
|
9 |
+
from utils import get_lm
|
10 |
+
|
11 |
|
12 |
class Generator:
|
13 |
def __init__(self, lm_model_name, device, entropy=None):
|
|
|
17 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
18 |
lm_model_name
|
19 |
)
|
20 |
+
self.lm = get_lm(lm_model_name).to(device)
|
|
|
|
|
21 |
self.lm.eval()
|
22 |
|
23 |
self.lm.config.pad_token_id = self.lm.config.eos_token_id
|
sampling.py
CHANGED
@@ -7,6 +7,8 @@ from torch.nn import functional as F
|
|
7 |
|
8 |
import transformers
|
9 |
|
|
|
|
|
10 |
|
11 |
def sample_from_values(unscaled_probs, values):
|
12 |
samples = torch.multinomial(unscaled_probs, 1)
|
@@ -37,9 +39,7 @@ class CAIFSampler:
|
|
37 |
classifier_name
|
38 |
)
|
39 |
self.classifier_model = (
|
40 |
-
|
41 |
-
classifier_name
|
42 |
-
).to(device)
|
43 |
)
|
44 |
self.classifier_model.eval()
|
45 |
self.lm_tokenizer = lm_tokenizer
|
|
|
7 |
|
8 |
import transformers
|
9 |
|
10 |
+
from utils import get_cls
|
11 |
+
|
12 |
|
13 |
def sample_from_values(unscaled_probs, values):
|
14 |
samples = torch.multinomial(unscaled_probs, 1)
|
|
|
39 |
classifier_name
|
40 |
)
|
41 |
self.classifier_model = (
|
42 |
+
get_cls(classifier_name).to(device)
|
|
|
|
|
43 |
)
|
44 |
self.classifier_model.eval()
|
45 |
self.lm_tokenizer = lm_tokenizer
|
utils.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
|
3 |
+
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
|
4 |
+
|
5 |
+
|
6 |
+
@lru_cache(3)
|
7 |
+
def get_lm(lm_name):
|
8 |
+
return AutoModelForCausalLM.from_pretrained(lm_name)
|
9 |
+
|
10 |
+
|
11 |
+
@lru_cache(3)
|
12 |
+
def get_cls(cls_name):
|
13 |
+
return AutoModelForSequenceClassification.from_pretrained(cls_name)
|