Балаганский Никита Николаевич committed on
Commit
e852933
1 Parent(s): 570bb74

everything wrapped in cache

Browse files
Files changed (3) hide show
  1. generator.py +3 -3
  2. sampling.py +3 -3
  3. utils.py +13 -0
generator.py CHANGED
@@ -6,6 +6,8 @@ import streamlit as st
6
 
7
  from plotly import graph_objects as go
8
 
 
 
9
 
10
  class Generator:
11
  def __init__(self, lm_model_name, device, entropy=None):
@@ -15,9 +17,7 @@ class Generator:
15
  self.tokenizer = transformers.AutoTokenizer.from_pretrained(
16
  lm_model_name
17
  )
18
- self.lm = transformers.AutoModelForCausalLM.from_pretrained(
19
- lm_model_name
20
- ).to(device)
21
  self.lm.eval()
22
 
23
  self.lm.config.pad_token_id = self.lm.config.eos_token_id
6
 
7
  from plotly import graph_objects as go
8
 
9
+ from utils import get_lm
10
+
11
 
12
  class Generator:
13
  def __init__(self, lm_model_name, device, entropy=None):
17
  self.tokenizer = transformers.AutoTokenizer.from_pretrained(
18
  lm_model_name
19
  )
20
+ self.lm = get_lm(lm_model_name).to(device)
 
 
21
  self.lm.eval()
22
 
23
  self.lm.config.pad_token_id = self.lm.config.eos_token_id
sampling.py CHANGED
@@ -7,6 +7,8 @@ from torch.nn import functional as F
7
 
8
  import transformers
9
 
 
 
10
 
11
  def sample_from_values(unscaled_probs, values):
12
  samples = torch.multinomial(unscaled_probs, 1)
@@ -37,9 +39,7 @@ class CAIFSampler:
37
  classifier_name
38
  )
39
  self.classifier_model = (
40
- transformers.AutoModelForSequenceClassification.from_pretrained(
41
- classifier_name
42
- ).to(device)
43
  )
44
  self.classifier_model.eval()
45
  self.lm_tokenizer = lm_tokenizer
7
 
8
  import transformers
9
 
10
+ from utils import get_cls
11
+
12
 
13
  def sample_from_values(unscaled_probs, values):
14
  samples = torch.multinomial(unscaled_probs, 1)
39
  classifier_name
40
  )
41
  self.classifier_model = (
42
+ get_cls(classifier_name).to(device)
 
 
43
  )
44
  self.classifier_model.eval()
45
  self.lm_tokenizer = lm_tokenizer
utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+
3
+ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
4
+
5
+
6
@lru_cache(3)
def get_lm(lm_name):
    """Return an ``AutoModelForCausalLM`` loaded from *lm_name*.

    Results are memoized: at most 3 distinct models are kept alive
    (LRU eviction), so repeated requests for the same name reuse one
    loaded instance instead of re-downloading/re-instantiating it.

    NOTE(review): callers mutate the returned model in place (e.g.
    ``.to(device)``, config edits), and those mutations are shared by
    every later cache hit — confirm this sharing is intended.
    """
    model = AutoModelForCausalLM.from_pretrained(lm_name)
    return model
9
+
10
+
11
@lru_cache(3)
def get_cls(cls_name):
    """Return an ``AutoModelForSequenceClassification`` for *cls_name*.

    Memoized with an LRU bound of 3 distinct classifier names, so the
    same classifier object is reused across calls rather than loaded
    again each time.

    NOTE(review): the cached classifier is mutated by callers
    (``.to(device)``) and that state is shared by subsequent cache
    hits — verify all call sites use the same device.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(cls_name)
    return classifier