import torch

# Large negative fill value used in place of -inf for masked positions.
NEGATIVE_INF = -100000.0
# Smaller-magnitude fill value for half precision; float16 cannot represent
# magnitudes above ~65504, so NEGATIVE_INF itself does not fit in that dtype.
HALF_NEGATIVE_INF = -60000.0


def get_first_sentence(txt: str, min_len: int = 5) -> str:
    """Return the first sentence of *txt*, ignoring text after the EOS marker.

    The text is first truncated at the first ``<|endoftext|>`` marker (when
    the marker is present and not at position 0), newlines are turned into
    spaces, and the result is split on ``'. '``.

    Args:
        txt: Raw text, possibly containing an ``<|endoftext|>`` marker.
        min_len: Minimum length (in characters, before stripping) that the
            first split segment must have to be accepted as a sentence.

    Returns:
        The first segment, stripped and terminated with ``'.'``, when it has
        at least ``min_len`` characters; otherwise the whole cleaned text
        (truncated at the marker, newlines replaced) unchanged.
    """
    eos = '<|endoftext|>'
    eos_idx = txt.find(eos)
    # Keep only what precedes the marker. A marker at index 0 (or no marker,
    # where find() returns -1) leaves the text untouched, as before.
    if eos_idx > 0:
        txt = txt[:eos_idx]

    txt = txt.replace('\n', ' ')
    first = txt.split('. ')[0]
    if len(first) >= min_len:
        return f'{first.strip()}.'
    # First segment too short to count as a sentence: fall back to full text.
    return txt


def logits_to_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Return the entropy of the categorical distribution defined by *logits*.

    Args:
        logits: Unnormalized log-probabilities passed to
            ``torch.distributions.Categorical``, which interprets the last
            dimension as the category axis.

    Returns:
        The value of ``Categorical(logits=logits).entropy()``.
    """
    return torch.distributions.Categorical(logits=logits).entropy()


def mask_pad(value, mask):
    """Replace masked-out entries of *value* with ``NEGATIVE_INF``.

    Computes ``value * mask + NEGATIVE_INF * (1 - mask)``: where ``mask`` is
    1 the original value is kept, where it is 0 the result is
    ``NEGATIVE_INF``.

    Args:
        value: Values to mask.
        mask: Multiplicative mask; presumably 1 for real positions and 0 for
            padding, with a shape compatible with ``value`` (TODO confirm
            against callers).

    Returns:
        The masked values.
    """
    # NOTE(review): NEGATIVE_INF is outside the float16 range; if this is ever
    # called with half-precision tensors, HALF_NEGATIVE_INF may be the
    # intended fill value -- confirm with callers before changing.
    return value * mask + NEGATIVE_INF * (1 - mask)