Jinkin commited on
Commit
cf5524e
1 Parent(s): f476313

Upload utils.py

Browse files
Files changed (1) hide show
  1. eval/utils.py +56 -0
eval/utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+
4
+ from torch import Tensor
5
+ from typing import Mapping
6
+
7
+
8
+ def _setup_logger():
9
+ log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
10
+ logger = logging.getLogger()
11
+ logger.setLevel(logging.INFO)
12
+
13
+ console_handler = logging.StreamHandler()
14
+ console_handler.setFormatter(log_format)
15
+ logger.handlers = [console_handler]
16
+
17
+ return logger
18
+
19
+
20
+ logger = _setup_logger()
21
+
22
+
23
+ def move_to_cuda(sample):
24
+ if len(sample) == 0:
25
+ return {}
26
+
27
+ def _move_to_cuda(maybe_tensor):
28
+ if torch.is_tensor(maybe_tensor):
29
+ return maybe_tensor.cuda(non_blocking=True)
30
+ elif isinstance(maybe_tensor, dict):
31
+ return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
32
+ elif isinstance(maybe_tensor, list):
33
+ return [_move_to_cuda(x) for x in maybe_tensor]
34
+ elif isinstance(maybe_tensor, tuple):
35
+ return tuple([_move_to_cuda(x) for x in maybe_tensor])
36
+ elif isinstance(maybe_tensor, Mapping):
37
+ return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
38
+ else:
39
+ return maybe_tensor
40
+
41
+ return _move_to_cuda(sample)
42
+
43
+
44
+ def pool(last_hidden_states: Tensor,
45
+ attention_mask: Tensor,
46
+ pool_type: str) -> Tensor:
47
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
48
+
49
+ if pool_type == "avg":
50
+ emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
51
+ elif pool_type == "cls":
52
+ emb = last_hidden[:, 0]
53
+ else:
54
+ raise ValueError(f"pool_type {pool_type} not supported")
55
+
56
+ return emb