furrutiav commited on
Commit
ca25754
1 Parent(s): d9c22e0

Update util.py

Browse files
Files changed (1) hide show
  1. util.py +52 -3
util.py CHANGED
@@ -3,9 +3,58 @@ import torch
3
  from transformers import BertTokenizer, BertModel
4
  from huggingface_hub import hf_hub_url, cached_download
5
 
6
- def get_cls_layer():
7
- config_file_url = hf_hub_url("furrutiav/beto_coherence", filename="cls_layer.torch")
8
  value = cached_download(config_file_url)
9
  return torch.load(value, map_location=torch.device('cpu'))
10
 
11
- get_cls_layer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from transformers import BertTokenizer, BertModel
4
  from huggingface_hub import hf_hub_url, cached_download
5
 
6
+ def get_cls_layer(repo_id="furrutiav/beto_coherence"):
7
+ config_file_url = hf_hub_url(repo_id, filename="cls_layer.torch")
8
  value = cached_download(config_file_url)
9
  return torch.load(value, map_location=torch.device('cpu'))
10
 
11
+ cls_layer = get_cls_layer()
12
+
13
+ beto_model = BertModel.from_pretrained("furrutiav/beto_coherence", revision="df96f50cfb1e3f7923912a25b1c3a865116fae4a")
14
+
15
+ beto_tokenizer = BertTokenizer.from_pretrained("furrutiav/beto_coherence", revision="df96f50cfb1e3f7923912a25b1c3a865116fae4a", do_lower_case=False)
16
+
17
+ e = beto_model.eval()
18
+
19
+ def preproccesing(Q, A, maxlen=60):
20
+ Q = " ".join(str(Q).replace("\n", " ").split())
21
+ A = " ".join(str(A).replace("\n", " ").split())
22
+ Q = Q if Q != "" else "nan"
23
+ A = A if A != "" else "nan"
24
+
25
+ tokens1 = beto_tokenizer.tokenize(Q)
26
+ tokens1 = ['[CLS]'] + tokens1 + ['[SEP]']
27
+ if len(tokens1) < maxlen:
28
+ tokens1 = tokens1 + ['[PAD]' for _ in range(maxlen - len(tokens1))]
29
+ else:
30
+ tokens1 = tokens1[:maxlen-1] + ['[SEP]']
31
+
32
+ tokens2 = beto_tokenizer.tokenize(A)
33
+ tokens2 = tokens2 + ['[SEP]']
34
+ if len(tokens2) < maxlen:
35
+ tokens2 = tokens2 + ['[PAD]' for _ in range(maxlen - len(tokens2))]
36
+ else:
37
+ tokens2 = tokens2[:maxlen-1] + ['[SEP]']
38
+
39
+ tokens = tokens1+tokens2
40
+ tokens_ids = beto_tokenizer.convert_tokens_to_ids(tokens)
41
+ tokens_ids_tensor = torch.tensor(tokens_ids)
42
+
43
+ attn_mask = (tokens_ids_tensor != 1).long()
44
+ return tokens_ids_tensor, attn_mask
45
+
46
+ def C1Classifier(Q, A, is_probs=True):
47
+ tokens_ids_tensor, attn_mask = preproccesing(Q, A)
48
+ cont_reps = beto_model(tokens_ids_tensor.unsqueeze(0), attention_mask = attn_mask.unsqueeze(0))
49
+ cls_rep = cont_reps.last_hidden_state[:, 0]
50
+ logits = cls_layer(cls_rep)
51
+ probs = torch.sigmoid(logits)
52
+ soft_probs = probs.argmax(1)
53
+ if is_probs:
54
+ return probs.detach().numpy()[0]
55
+ else:
56
+ return soft_probs.numpy()[0]
57
+
58
+
59
+
60
+