import math
import random

import torch


def log(t, eps = 1e-20):
    # numerically stable log: clamp the input so log(0) never produces -inf
    return torch.log(t.clamp(min = eps))


def gumbel_noise(t):
    # Gumbel(0, 1) noise via -log(-log(U)), U ~ Uniform[0, 1)
    noise = torch.rand_like(t)
    return -log(-log(noise))


def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
    # Gumbel-max trick: adding Gumbel noise to the scaled logits and taking the
    # argmax is equivalent to sampling from softmax(t / temperature)
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)


def top_k(logits, thres = 0.9):
    # keep the top (1 - thres) fraction of logits, masking the rest to -inf
    # so they can never be sampled
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    filtered_logits = torch.full_like(logits, float('-inf'))
    filtered_logits.scatter_(-1, ind, val)
    return filtered_logits
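

# Quick illustration of how the samplers above compose (shapes only; the
# values here are random stand-ins, not model outputs):
#
#     logits = torch.randn(2, 256)                # (batch, vocab)
#     filtered = top_k(logits, thres = 0.9)       # keeps ceil(0.1 * 256) = 26 logits per row
#     token_ids = gumbel_sample(filtered, temperature = 0.7)
#     token_ids.shape                             # torch.Size([2, 1])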


@torch.no_grad()
def generate_text(model, tokenizer, prompt: torch.Tensor, seq_len: int):
    prompt_seq_len = prompt.shape[-1]

    h_states = None
    logits = None
    text = ""

    # feed the prompt one token at a time to build up the recurrent hidden state
    for i in range(prompt_seq_len):
        tok = prompt[:, i:i + 1]
        logits, h_states = model.step(tok, h_states)

    # autoregressively sample: filter the logits, draw a token, advance the model
    for _ in range(seq_len):
        logits = top_k(logits, thres = 0.9)
        token = gumbel_sample(logits, temperature = 0.7, dim = -1)  # (batch, 1)

        logits, h_states = model.step(token, h_states)

        text += tokenizer.decode(token.item())

    return text
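

# `generate_text` assumes a single-step recurrent interface,
# `model.step(token, hidden) -> (logits, hidden)` with `token` of shape
# (batch, 1), and a tokenizer whose `decode` maps one token id to a string;
# see the smoke test at the bottom of the file for stand-in implementations.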


def generate_name():
    # short random run identifier, e.g. "mingru-1a2f"
    prefix = "mingru"
    random_number = random.randint(0, 0xFFFF)
    hex_code = f"{random_number:04x}"
    unique_name = f"{prefix}-{hex_code}"
    return unique_name
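

if __name__ == "__main__":
    # Smoke test with hypothetical stand-ins (not a real model or tokenizer):
    # DummyModel satisfies the assumed step(token, hidden) -> (logits, hidden)
    # contract with random logits, and DummyTokenizer decodes token ids as
    # printable ASCII characters.

    class DummyModel:
        vocab_size = 95  # number of printable ASCII characters

        def step(self, token, hidden):
            # random (batch, vocab) logits; the hidden state is passed through unchanged
            return torch.randn(token.shape[0], self.vocab_size), hidden

    class DummyTokenizer:
        def decode(self, token_id):
            return chr(32 + token_id % 95)

    prompt = torch.randint(0, 95, (1, 8))  # (batch = 1, prompt_len = 8)
    sample = generate_text(DummyModel(), DummyTokenizer(), prompt, seq_len = 32)
    print(f"run {generate_name()}: {sample!r}")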