hqsiswiliam commited on
Commit
2d000c7
1 Parent(s): 7254ba9

Update perplexity.py

Browse files

Add model & tokenizer parameters, to avoid reinitialising the model every time on _compute()

Files changed (1) hide show
  1. perplexity.py +8 -6
perplexity.py CHANGED
@@ -49,6 +49,8 @@ Args:
49
  add_start_token (bool): whether to add the start token to the texts,
50
  so the perplexity can include the probability of the first word. Defaults to True.
51
  device (str): device to run on, defaults to 'cuda' when available
 
 
52
  Returns:
53
  perplexity: dictionary containing the perplexity scores for the texts
54
  in the input list, as well as the mean perplexity. If one of the input texts is
@@ -101,7 +103,7 @@ class Perplexity(evaluate.Metric):
101
  )
102
 
103
  def _compute(
104
- self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
105
  ):
106
 
107
  if device is not None:
@@ -110,11 +112,11 @@ class Perplexity(evaluate.Metric):
110
  device = "cuda"
111
  else:
112
  device = "cuda" if torch.cuda.is_available() else "cpu"
113
-
114
- model = AutoModelForCausalLM.from_pretrained(model_id)
115
- model = model.to(device)
116
-
117
- tokenizer = AutoTokenizer.from_pretrained(model_id)
118
 
119
  # if batch_size > 1 (which generally leads to padding being required), and
120
  # if there is not an already assigned pad_token, assign an existing
 
49
  add_start_token (bool): whether to add the start token to the texts,
50
  so the perplexity can include the probability of the first word. Defaults to True.
51
  device (str): device to run on, defaults to 'cuda' when available
52
+ model (AutoModelForCausalLM): the model for calculating Perplexity, if provided, the model won't initialized from model_id
53
+ tokenizer (AutoTokenizer): the tokenizer for calculating Perplexity, if provided, the tokenizer won't initialized from model_id
54
  Returns:
55
  perplexity: dictionary containing the perplexity scores for the texts
56
  in the input list, as well as the mean perplexity. If one of the input texts is
 
103
  )
104
 
105
  def _compute(
106
+ self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None, model = None, tokenizer = None
107
  ):
108
 
109
  if device is not None:
 
112
  device = "cuda"
113
  else:
114
  device = "cuda" if torch.cuda.is_available() else "cpu"
115
+ if model is None:
116
+ model = AutoModelForCausalLM.from_pretrained(model_id)
117
+ model = model.to(device)
118
+ if tokenizer is None:
119
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
120
 
121
  # if batch_size > 1 (which generally leads to padding being required), and
122
  # if there is not an already assigned pad_token, assign an existing