Spaces:
Runtime error
Runtime error
import torch | |
def log_prob_to_prob(log_probs, temp=1):
    """
    Convert log probabilities to a normalized probability distribution.

    Args:
        log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size).
            The softmax is taken over the last dimension.
        temp (float): Temperature the shifted log probs are divided by before
            the softmax. Defaults to 1.
    Returns:
        Probability distribution with the same shape as `log_probs`
        (n_prompts, n_drafts, vocab_size)
    """
    # Stability shift: subtract the per-row max so the largest entry becomes 0.
    # Softmax is shift-invariant, so the result is unchanged, but magnitudes
    # stay small. Adding the max instead doubles the magnitude of negative log
    # probs and can overflow to -inf, which turns the softmax output into NaN.
    log_probs = log_probs - torch.max(log_probs, dim=-1, keepdim=True)[0]
    probs = torch.softmax(log_probs / temp, dim=-1)
    return probs
def decode(tokenizer, encoding):
    """
    Decode a tensor of token ids to a string, stopping at the first pad.

    Entries equal to -1 are treated as padding: everything from the first -1
    onwards is dropped before decoding.

    Args:
        tokenizer (Any): Object exposing `decode(list_of_ints)`
        encoding (torch.Tensor): Token ids; assumed to be 1-D, since the
            truncation index is read with `.item()` - TODO confirm
    Returns:
        decoding (str): Whatever `tokenizer.decode` returns for the kept ids
    """
    pad_locs = (encoding == -1).nonzero()
    # Truncate only when at least one pad position exists. The length of the
    # index tensor is what matters, not the values of the indices.
    if len(pad_locs) > 0:
        encoding = encoding[:pad_locs[0].item()]
    return tokenizer.decode(encoding.to(torch.int32).tolist())
def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
    """
    Print out generations for debugging.

    Each group of `n_drafts` consecutive rows is treated as the drafts of one
    prompt: a separator and the decoded prompt are printed before the first
    draft of the group, then one line per draft with its log prob and index.

    Args:
        gens (torch.Tensor): Generations to print, either
            (n_prompts, n_drafts, seq_len) or already flattened to
            (n_prompts * n_drafts, seq_len)
        logprobs (torch.Tensor): Log prob of each generation; flattened
            before use, so any shape with n_prompts * n_drafts entries works
        tokenizer (any): Tokenizer
        n_drafts (int): Number of drafts per prompt. Only used when `gens`
            is not 3-D; for 3-D input the value is taken from `gens.shape[1]`
        prompt_len (int): Number of tokens in prompt
        output_file: File-like object passed to `print(..., file=...)`
    """
    # For 3-D input the tensor's own layout is authoritative (this matches the
    # previous behaviour, which always read n_drafts from the shape).
    if gens.dim() == 3:
        n_drafts = gens.shape[1]
    seq_len = gens.shape[-1]
    gens = gens.reshape(-1, seq_len)
    logprobs = logprobs.flatten()
    for i, gen in enumerate(gens):
        d = decode(tokenizer, gen)
        # position of this draft within its prompt's group
        count = i % n_drafts
        # first draft of this prompt
        if count == 0:
            print("---------------", file=output_file)
            prompt = decode(tokenizer, gen[:prompt_len])
            print(f"prompt: {prompt}", file=output_file)
        print(f"logprob: {logprobs[i]} {count}: {d}", file=output_file)
def print_probs(next_probs, tokenizer, output_file):
    """
    Print out next token options and probabilities for debugging.

    For every (prompt, draft) pair, prints the top `n_drafts + 2` candidate
    tokens (fewer if the vocabulary is smaller than that), their log probs
    and their probs.

    Args:
        next_probs (torch.Tensor): Next token probabilities (n_prompts, n_drafts, vocab_size)
        tokenizer (any): Tokenizer; `decode` is called with a single-id list
        output_file: File-like object passed to `print(..., file=...)`
    """
    print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
    n_prompts, n_drafts, vocab_size = next_probs.shape
    # topk raises if k exceeds the size of the dimension, so clamp to the vocab.
    k = min(n_drafts + 2, vocab_size)
    for p_idx in range(n_prompts):
        print(f"\tPrompt {p_idx}:", file=output_file)
        for d_idx in range(n_drafts):
            next_token_probs, next_token_idx = next_probs[p_idx, d_idx].topk(k, dim=-1)
            tokens = [tokenizer.decode([i.item()]) for i in next_token_idx]
            print(f"\t\tTokens: {tokens}", file=output_file)
            print(f"\t\tLog Probs: {torch.log(next_token_probs)}", file=output_file)
            print(f"\t\tProbs: {next_token_probs}", file=output_file)