File size: 622 Bytes
0bf81ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch


# Large *finite* stand-in for -inf used when masking. With a true -inf,
# mask_pad's ``NEGATIVE_INF * (1 - mask)`` would yield NaN where mask is 1
# (since -inf * 0 is NaN); a finite value keeps that term at zero.
NEGATIVE_INF = -1e5
# Half-precision variant: float16's largest finite magnitude is 65504, so -1e5
# would overflow to -inf in float16, while -6e4 is still representable.
HALF_NEGATIVE_INF = -6e4


def get_first_sentence(txt, min_len=5):
    """Return the first sentence of ``txt``, ignoring anything after end-of-text.

    Everything from the first ``'<|endoftext|>'`` marker onward is discarded,
    and newlines are replaced by spaces. The text is then split on ``'. '``.
    If the first piece has at least ``min_len`` characters, it is returned
    stripped and terminated by exactly one ``'.'``. Otherwise, the whole
    truncated, newline-flattened text is returned.

    Args:
        txt: Input text, e.g. generated text that may contain the
            ``'<|endoftext|>'`` marker.
        min_len: Minimum length (measured before stripping) for the first
            piece to be accepted as a sentence.

    Returns:
        The first sentence, or the cleaned full text when that first piece is
        shorter than ``min_len``. Text that begins with the marker yields ``''``.
    """
    eos = '<|endoftext|>'
    eos_idx = txt.find(eos)
    if eos_idx != -1:
        # Keep only what comes *before* the marker (a marker at index 0
        # means there is no content at all).
        txt = txt[:eos_idx]
    txt = txt.replace('\n', ' ')
    sents = txt.split('. ')
    first = sents[0]
    if len(first) < min_len:
        return txt
    sent = first.strip()
    # When a '. ' separator was found, the split consumed this sentence's
    # period, so restore it. Without a separator, add a period only if the
    # text does not already end with one (avoids "Done.." output).
    if len(sents) > 1 or not sent.endswith('.'):
        sent = f'{sent}.'
    return sent


def logits_to_entropy(logits):
    """Compute the entropy of the categorical distribution(s) defined by ``logits``.

    ``logits`` are unnormalized log-probabilities. Following
    ``torch.distributions.Categorical``, the last dimension indexes the
    categories and any leading dimensions act as a batch, so the result has
    shape ``logits.shape[:-1]``. Entropy is in nats (natural logarithm).
    """
    return torch.distributions.Categorical(logits=logits).entropy()


def mask_pad(value, mask, pad_value=None):
    """Replace padded positions of ``value`` with a large negative fill value.

    Computes ``value * mask + pad_value * (1 - mask)``: where ``mask`` is 1,
    the original value is kept; where it is 0, ``pad_value`` is used instead.
    Non-binary mask values interpolate linearly between the two.

    Args:
        value: Values to mask (typically a tensor).
        mask: Numeric mask, broadcastable against ``value``, with 1 for
            positions to keep and 0 for padding. Boolean tensors are not
            accepted, because torch does not support ``1 - mask`` on
            ``torch.bool``; convert with ``mask.float()`` first.
        pad_value: Fill value for padded positions. ``None`` (the default)
            means ``NEGATIVE_INF``. Pass ``HALF_NEGATIVE_INF`` when working in
            float16, whose finite range does not include -1e5.

    Returns:
        ``value`` with padded positions replaced by ``pad_value``.
    """
    if pad_value is None:
        # Look the module constant up at call time rather than binding it as a
        # definition-time default.
        pad_value = NEGATIVE_INF
    return value * mask + pad_value * (1 - mask)