import torch


def log_prob_to_prob(log_probs, temp=1):
    """
    Convert log probabilities to probability distribution and normalize.

    Args:
        log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size)
        temp (float): Temperature the shifted log probs are divided by
            before the softmax. Must be non-zero.
    Returns:
        Probability distribution (n_prompts, n_drafts, vocab_size)
    """
    # Stability constant: shift so the largest entry along the last dim is 0.
    # Softmax is invariant to a constant shift, so the result is unchanged,
    # but subtracting (rather than adding) the max keeps magnitudes small and
    # avoids overflowing to -inf for very negative inputs.
    shifted = log_probs - torch.max(log_probs, dim=-1, keepdim=True)[0]
    return torch.softmax(shifted / temp, dim=-1)


def decode(tokenizer, encoding):
    """
    Decode a list of tokens to a string.

    Everything from the first -1 entry onwards is treated as padding and
    dropped before decoding.

    Args:
        tokenizer (Any): Tokenizer; must expose ``decode(list_of_ints)``
        encoding (torch.Tensor): 1-D encoding, optionally padded with -1
    Returns:
        decoding (str)
    """
    pad_locs = (encoding == -1).nonzero()
    # One row per -1 entry; cut the encoding at the first one.
    if len(pad_locs) > 0:
        encoding = encoding[:pad_locs[0].item()]
    return tokenizer.decode(encoding.to(torch.int32).tolist())


def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
    """
    Print out generations for debugging.

    Args:
        gens (torch.Tensor): Generations to print, either
            (n_prompts, n_drafts, seq_len) or already flattened to
            (n_prompts * n_drafts, seq_len)
        logprobs (torch.Tensor): Log probs of each generation; flattened
            internally, so it needs n_prompts * n_drafts elements
        tokenizer (any): Tokenizer
        n_drafts (int): Number of drafts per prompt. Only used when `gens`
            is not 3-D; for 3-D input the value is read from ``gens.shape``.
        prompt_len (int): Number of tokens in prompt
        output_file (file-like): Destination passed to ``print(file=...)``
    """
    # For 3-D input the tensor shape is authoritative for the draft count.
    if gens.dim() == 3:
        n_drafts = gens.shape[1]
    seq_len = gens.shape[-1]
    gens = gens.reshape(-1, seq_len)
    logprobs = logprobs.flatten()

    for i, gen in enumerate(gens):
        draft_idx = i % n_drafts
        # First draft of this prompt: emit a separator and the prompt text.
        if draft_idx == 0:
            print("---------------", file=output_file)
            prompt = decode(tokenizer, gen[:prompt_len])
            print(f"prompt: {prompt}", file=output_file)
        text = decode(tokenizer, gen)
        print(f"logprob: {logprobs[i]} {draft_idx}: {text}", file=output_file)


def print_probs(next_probs, tokenizer, output_file):
    """
    Print out next token options and probabilities for debugging.

    Args:
        next_probs (torch.Tensor): Next token probabilities
            (n_prompts, n_drafts, vocab_size)
        tokenizer (any): Tokenizer
        output_file (file-like): Destination passed to ``print(file=...)``
    """
    print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
    n_prompts, n_drafts, vocab_size = next_probs.shape
    # topk raises if k exceeds the size of the dimension, so clamp it.
    k = min(n_drafts + 2, vocab_size)
    for p_idx in range(n_prompts):
        print(f"\tPrompt {p_idx}:", file=output_file)
        for d_idx in range(n_drafts):
            next_token_probs, next_token_idx = next_probs[p_idx, d_idx].topk(k, dim=-1)
            print(f"\t\tTokens: {[tokenizer.decode([i.item()]) for i in next_token_idx]}", file=output_file)
            print(f"\t\tLog Probs: {torch.log(next_token_probs)}", file=output_file)
            print(f"\t\tProbs: {next_token_probs}", file=output_file)