File size: 2,765 Bytes
dda1539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch

def log_prob_to_prob(log_probs, temp=1):
    """
    Convert log probabilities to probability distribution and normalize.

    Normalization is a temperature-scaled softmax over the last dimension.
    Args:
        log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size)
        temp (float): Temperature the shifted log probs are divided by
            before the softmax. Defaults to 1.
    Returns:
        Probability distribution (n_prompts, n_drafts, vocab_size)
    """
    # Stability shift: subtract the per-row maximum so the largest entry is 0.
    # Softmax is invariant to a constant shift, so the result is unchanged,
    # whereas *adding* the maximum doubles the magnitude of the values and can
    # overflow very negative rows to -inf (which softmax turns into NaN).
    row_max = torch.max(log_probs, dim=-1, keepdim=True)[0]
    shifted = log_probs - row_max
    return torch.softmax(shifted / temp, dim=-1)

def decode(tokenizer, encoding):
    """
    Decode a tensor of tokens to a string, stopping at the first pad token.

    Positions equal to -1 are treated as padding: everything from the first
    -1 onwards is dropped before decoding.
    Args:
        tokenizer (Any): Tokenizer; its decode() is called with a list of ints
        encoding (torch.Tensor): Encoding (indexed as a 1-D sequence of ids)
    Returns:
        decoding (str)
    """
    pad_locs = (encoding == -1).nonzero()
    # Check whether any pad position was found. (Comparing the indices
    # themselves to zero and taking len() of that only worked by accident.)
    if pad_locs.numel() > 0:
        first_pad = pad_locs[0].item()
        encoding = encoding[:first_pad]
    token_ids = encoding.to(torch.int32).tolist()
    return tokenizer.decode(token_ids)

def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
    """
    Print out generations for debugging.

    For every prompt a separator and the decoded prompt are printed, followed
    by one line per draft with its log prob, draft index and decoded text.
    Args:
        gens: Generations to print, either (n_prompts, n_drafts, seq_len) or
            already flattened to (n_prompts * n_drafts, seq_len)
        logprobs: Log probs of each generation; flattened to
            (n_prompts * n_drafts) before printing
        tokenizer (any): Tokenizer
        n_drafts (int): Number of drafts per prompt. For 3-D gens the value
            from gens.shape takes precedence over this argument.
        prompt_len (int): Number of tokens in prompt
        output_file: File-like object handed to print(file=...)
    """
    # A 3-D input carries its own draft count; a 2-D (pre-flattened) input
    # relies on the caller-supplied n_drafts.
    if len(gens.shape) == 3:
        n_drafts = gens.shape[1]
    seq_len = gens.shape[-1]
    flat_gens = gens.reshape(-1, seq_len)
    flat_logprobs = logprobs.flatten()
    for i, gen in enumerate(flat_gens):
        text = decode(tokenizer, gen)
        draft_idx = i % n_drafts
        # first draft of this prompt
        if draft_idx == 0:
            print("---------------", file=output_file)
            prompt = decode(tokenizer, gen[:prompt_len])
            print(f"prompt: {prompt}", file=output_file)
        print(f"logprob: {flat_logprobs[i]} {draft_idx}: {text}", file=output_file)
        
def print_probs(next_probs, tokenizer, output_file):
    """
    Print out next token options and probabilities for debugging

    For each (prompt, draft) pair the highest-probability next tokens are
    printed: up to n_drafts + 2 of them, capped at the vocabulary size.
    Args:
        next_probs (torch.Tensor): Next token probabilities (n_prompts, n_drafts, vocab_size)
        tokenizer (any): Tokenizer; its decode() is called with a one-element list of ints
        output_file: File-like object handed to print(file=...)
    """
    print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
    n_prompts, n_drafts, vocab_size = next_probs.shape
    # topk raises when k exceeds the size of the dimension, so never ask for
    # more candidates than the vocabulary holds.
    k = min(n_drafts + 2, vocab_size)
    for p_idx in range(n_prompts):
        print(f"\tPrompt {p_idx}:", file=output_file)
        for d_idx in range(n_drafts):
            next_token_probs, next_token_idx = next_probs[p_idx, d_idx].topk(k, dim=-1)
            tokens = [tokenizer.decode([i.item()]) for i in next_token_idx]
            print(f"\t\tTokens: {tokens}", file=output_file)
            print(f"\t\tLog Probs: {torch.log(next_token_probs)}", file=output_file)
            print(f"\t\tProbs: {next_token_probs}", file=output_file)