davebulaval commited on
Commit
495f38b
Β·
1 Parent(s): ba092fc

remove device handling

Browse files
Files changed (1) hide show
  1. meaningbert.py +3 -5
meaningbert.py CHANGED
@@ -67,7 +67,6 @@ MeaningBERT metric for assessing meaning preservation between sentences.
67
  Args:
68
  predictions (list of str): Predictions sentences.
69
  references (list of str): References sentences (same number of element as predictions).
70
- device (str): Device to use for model inference. By default, set to "cuda".
71
 
72
  Returns:
73
  score: the meaning score between two sentences in alist format respecting the order of the predictions and
@@ -78,7 +77,7 @@ Examples:
78
 
79
  >>> references = ["hello there", "general kenobi"]
80
  >>> predictions = ["hello there", "general kenobi"]
81
- >>> meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")
82
  >>> results = meaning_bert.compute(predictions=predictions, references=references)
83
  """
84
 
@@ -113,7 +112,6 @@ class MeaningBERT(evaluate.Metric):
113
  self,
114
  predictions: List,
115
  references: List,
116
- device: str = "cuda",
117
  ) -> Dict:
118
  assert len(references) == len(
119
  predictions
@@ -125,7 +123,7 @@ class MeaningBERT(evaluate.Metric):
125
 
126
  # We load the MeaningBERT pretrained model
127
  scorer = AutoModelForSequenceClassification.from_pretrained(
128
- "davebulaval/MeaningBERT", device_map=device
129
  )
130
  scorer.eval()
131
 
@@ -140,7 +138,7 @@ class MeaningBERT(evaluate.Metric):
140
  truncation=True,
141
  padding=True,
142
  return_tensors="pt",
143
- ).to(device)
144
 
145
  with filter_logging_context():
146
  # We process the text
 
67
  Args:
68
  predictions (list of str): Predictions sentences.
69
  references (list of str): References sentences (same number of element as predictions).
 
70
 
71
  Returns:
72
  score: the meaning score between two sentences in alist format respecting the order of the predictions and
 
77
 
78
  >>> references = ["hello there", "general kenobi"]
79
  >>> predictions = ["hello there", "general kenobi"]
80
+ >>> meaning_bert = evaluate.load("davebulaval/meaningbert")
81
  >>> results = meaning_bert.compute(predictions=predictions, references=references)
82
  """
83
 
 
112
  self,
113
  predictions: List,
114
  references: List,
 
115
  ) -> Dict:
116
  assert len(references) == len(
117
  predictions
 
123
 
124
  # We load the MeaningBERT pretrained model
125
  scorer = AutoModelForSequenceClassification.from_pretrained(
126
+ "davebulaval/MeaningBERT"
127
  )
128
  scorer.eval()
129
 
 
138
  truncation=True,
139
  padding=True,
140
  return_tensors="pt",
141
+ )
142
 
143
  with filter_logging_context():
144
  # We process the text