xu1998hz committed
Commit a45b5e7
1 Parent(s): 127b211

fix roberta on sescore.py

Files changed (2):
  1. __init__.py +0 -37
  2. sescore.py +39 -1
__init__.py CHANGED
@@ -1,38 +1 @@
-import comet
-from typing import Dict
-import torch
-from comet.encoders.base import Encoder
-from comet.encoders.bert import BERTEncoder
-from transformers import AutoModel, AutoTokenizer
 
-class robertaEncoder(BERTEncoder):
-    def __init__(self, pretrained_model: str) -> None:
-        super(Encoder, self).__init__()
-        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
-        self.model = AutoModel.from_pretrained(
-            pretrained_model, add_pooling_layer=False
-        )
-        self.model.encoder.output_hidden_states = True
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model: str) -> Encoder:
-        return robertaEncoder(pretrained_model)
-
-    def forward(
-        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
-    ) -> Dict[str, torch.Tensor]:
-        last_hidden_states, _, all_layers = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            output_hidden_states=True,
-            return_dict=False,
-        )
-        return {
-            "sentemb": last_hidden_states[:, 0, :],
-            "wordemb": last_hidden_states,
-            "all_layers": all_layers,
-            "attention_mask": attention_mask,
-        }
-
-# initialize roberta into str2encoder
-comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
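The class body is unchanged by this commit; it moves verbatim into sescore.py below. The likely reason: evaluate.load() executes only the metric script itself, so the str2encoder registration that lived in the package __init__.py never ran on the Hub. Once the entry exists, COMET resolves encoders by name through that registry. A minimal sketch of the lookup, assuming str2encoder is a plain dict mapping model-type names to Encoder subclasses; the model name 'roberta-large' is a hypothetical stand-in:

import comet

# The 'RoBERTa' entry is the one this commit registers; RoBERTa-based
# checkpoints resolve to the custom robertaEncoder above.
encoder_cls = comet.encoders.str2encoder['RoBERTa']

# Hypothetical pretrained model name, for illustration only.
encoder = encoder_cls.from_pretrained('roberta-large')
batch = encoder.tokenizer('A short test sentence.', return_tensors='pt')
out = encoder.forward(batch['input_ids'], batch['attention_mask'])
print(out['sentemb'].shape)  # CLS embedding: (batch_size, hidden_size)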
sescore.py CHANGED
@@ -16,6 +16,42 @@
 import evaluate
 import datasets
 
+import comet
+from typing import Dict
+import torch
+from comet.encoders.base import Encoder
+from comet.encoders.bert import BERTEncoder
+from transformers import AutoModel, AutoTokenizer
+
+class robertaEncoder(BERTEncoder):
+    def __init__(self, pretrained_model: str) -> None:
+        super(Encoder, self).__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
+        self.model = AutoModel.from_pretrained(
+            pretrained_model, add_pooling_layer=False
+        )
+        self.model.encoder.output_hidden_states = True
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model: str) -> Encoder:
+        return robertaEncoder(pretrained_model)
+
+    def forward(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        last_hidden_states, _, all_layers = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            return_dict=False,
+        )
+        return {
+            "sentemb": last_hidden_states[:, 0, :],
+            "wordemb": last_hidden_states,
+            "all_layers": all_layers,
+            "attention_mask": attention_mask,
+        }
+
 
 # TODO: Add BibTeX citation
 _CITATION = """\
@@ -87,12 +123,14 @@ class SEScore(evaluate.Metric):
         from comet import load_from_checkpoint
         import gdown
         import os
+        # initialize roberta into str2encoder
+        comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
         url = "https://drive.google.com/uc?id=1QgMP_Y4QCbvDMTeVacYt0J76OYvwWK9V&export=download&confirm=true"
         output = 'sescore_ckpt.gz'
         gdown.download(url, output, quiet=False)
         cmd = 'tar -xvf sescore_ckpt.gz'
         os.system(cmd)
-        self.scorer = load_from_checkpoint('sescore_ckpt/zh_en/checkpoint/sescore_english.ckpt')
+        self.scorer = load_from_checkpoint('/home/user/app/sescore_ckpt/zh_en/checkpoint/sescore_english.ckpt')
 
     def _compute(self, sources, predictions, references, gpus=None, progress_bar=False):
         if gpus is None:
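Net effect: the encoder is registered immediately before the checkpoint is loaded, and the checkpoint path becomes the absolute /home/user/app/... location, which matches the working directory of a Hugging Face Space. For reference, a minimal usage sketch after this fix; the Hub path xu1998hz/sescore is an assumption, as is that compute() forwards the sources list through to _compute():

import evaluate

# Assumed Hub path, for illustration only.
sescore = evaluate.load('xu1998hz/sescore')

# _compute takes parallel lists of sources, predictions, and references.
results = sescore.compute(
    sources=['今天天气很好。'],
    predictions=['The weather is nice today.'],
    references=['It is a nice day today.'],
)
print(results)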