# Source: mingru.flop / util.py
# Author: flpelerin — last commit f4a4361 ("Update util.py", verified)
import torch
import math
import random
def log(t, eps = 1e-20):
    """Return the elementwise natural log of ``t``, flooring inputs at ``eps``.

    Flooring keeps ``log(0)`` (and log of tiny/negative values) finite instead
    of producing ``-inf``/``nan``.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()
def gumbel_noise(t):
    """Draw standard Gumbel(0, 1) noise shaped like ``t`` (same dtype and device).

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``;
    the clamped ``log`` helper keeps both logarithms finite at the endpoints.
    """
    uniform = torch.empty_like(t).uniform_(0, 1)  # every element is overwritten
    inner = -log(uniform)
    return -log(inner)
def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
    """Sample indices along ``dim`` via the Gumbel-max trick.

    The logits ``t`` are divided by ``temperature`` (floored at 1e-10 so a zero
    temperature becomes effectively greedy rather than dividing by zero), Gumbel
    noise is added, and the argmax index is returned.

    Args:
        t: logits tensor.
        temperature: sampling temperature; lower is greedier.
        dim: dimension to sample over.
        keepdim: whether the reduced dimension is kept (size 1) in the result.
    """
    safe_temperature = max(temperature, 1e-10)
    perturbed = t / safe_temperature + gumbel_noise(t)
    return torch.argmax(perturbed, dim=dim, keepdim=keepdim)
def top_k(logits, thres = 0.9):
    """Keep only the largest logits along the last dimension; mask the rest with -inf.

    The number of entries kept is ``ceil((1 - thres) * vocab)``, clamped to
    ``[1, vocab]``. The clamp means ``thres >= 1`` still keeps the single best
    logit (rather than masking everything, which would make downstream argmax
    sampling always pick index 0), and ``thres < 0`` keeps every logit rather
    than making ``torch.topk`` raise.

    Args:
        logits: score tensor; filtering is applied along the last dimension.
        thres: fraction of entries to discard (0.9 keeps roughly the top 10%).

    Returns:
        A tensor shaped like ``logits`` whose non-top-k entries are ``-inf``.
    """
    num_tokens = logits.shape[-1]
    k = math.ceil((1 - thres) * num_tokens)
    k = min(max(k, 1), num_tokens)
    val, ind = torch.topk(logits, k, dim=-1)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, ind, val)
    return filtered
def generate_text(model, tokenizer, prompt: torch.Tensor, seq_len: int,
                  temperature: float = .7, thres: float = .9) -> str:
    """Sample ``seq_len`` tokens from a step-wise recurrent ``model`` and decode them.

    The prompt is fed one token at a time through ``model.step`` to build the
    hidden state. Each new token is then drawn by top-k filtering followed by
    Gumbel-max sampling, fed back through ``model.step``, decoded individually
    and appended to the output. The prompt itself is not included in the text.

    Args:
        model: object exposing ``step(token, h_states) -> (logits, h_states)``;
            ``h_states`` starts as ``None``.
        tokenizer: object whose ``decode`` accepts a single integer token id.
        prompt: token ids indexed as ``prompt[:, i]``. Because each sampled token
            is converted with ``.item()``, a batch size of 1 is assumed — TODO confirm.
        seq_len: number of tokens to sample.
        temperature: Gumbel sampling temperature (default 0.7, the value that
            was previously hard-coded).
        thres: top-k threshold passed to ``top_k`` (default 0.9, previously
            hard-coded).

    Returns:
        The concatenation of the individually decoded sampled tokens.

    Raises:
        ValueError: if ``prompt`` is empty while ``seq_len > 0`` (there would be
            no logits to sample the first token from).
    """
    prompt_seq_len = prompt.shape[-1]
    if prompt_seq_len == 0 and seq_len > 0:
        raise ValueError("prompt must contain at least one token to generate from")

    h_states = None
    logits = None
    pieces = []
    # Sampling needs no gradients. If the model's parameters require grad, each
    # step's outputs would otherwise carry autograd history, and chaining
    # h_states through every step would keep that whole graph alive.
    with torch.no_grad():
        for i in range(prompt_seq_len):
            tok = prompt[:, i:i + 1]  # (batch, 1); batch assumed to be 1
            logits, h_states = model.step(tok, h_states)

        for _ in range(seq_len):
            filtered = top_k(logits, thres=thres)
            token = gumbel_sample(filtered, temperature=temperature, dim=-1)[0]
            logits, h_states = model.step(token, h_states)
            pieces.append(tokenizer.decode(token.item()))

    return "".join(pieces)
def generate_name():
    """Return a name of the form ``mingru-xxxx`` with a random 4-digit lowercase hex suffix.

    The suffix comes from the module-level ``random`` generator, so seeding
    ``random`` beforehand makes the result reproducible.
    """
    tag = "mingru"
    suffix = format(random.randint(0, 0xFFFF), "04x")
    return f"{tag}-{suffix}"