# Copyright (c) 2023, Albert Gu, Tri Dao.
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()
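
# Usage sketch (illustrative, not from the original file): during prompt processing the whole
# prompt is fed with seqlen_offset=0; after each forward pass the caller advances seqlen_offset
# by the number of tokens just processed so layers can index their state cache correctly, e.g.
#   params = InferenceParams(max_seqlen=512, max_batch_size=1)
#   # ...run the model on the prompt...
#   params.seqlen_offset += prompt_len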


def modify_logits_for_min_p_filtering(logits, min_p):
    """Set the logits for non-min_p values to -inf. Done in-place."""
    if min_p <= 0.0 or min_p >= 1.0:
        return
    indices_to_remove = logits < min_p
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort (ascending) and compute the cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability is at most 1 - top_p
    # (the highest-probability token is always kept).
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter the mask back to the original (unsorted) indexing.
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
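
# Worked example (illustrative, not from the original file): with per-token probabilities
# [0.5, 0.3, 0.15, 0.05] and top_p=0.9, the ascending cumulative sums are
# [0.05, 0.20, 0.50, 1.00]; only entries <= 1 - 0.9 = 0.1 are masked, so the
# 0.05-probability token is dropped and the remaining three stay available for sampling.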


def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
    logits: (batch_size, vocab_size)
    prev_output_tokens: (batch_size, seq_len)
    """
    if repetition_penalty == 1.0:
        return logits
    score = torch.gather(logits, 1, prev_output_tokens)
    # If score < 0, multiply by the penalty (rather than divide) so that the
    # previous token's probability is still reduced.
    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
    logits.scatter_(1, prev_output_tokens, score)
    return logits
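
# Example (illustrative, not from the original file): with repetition_penalty=1.2, a previously
# generated token with logit 2.0 becomes 2.0 / 1.2 ~ 1.67, while one with logit -2.0 becomes
# -2.0 * 1.2 = -2.4; in both cases the token becomes less likely to be sampled again.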


def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample from the logits, optionally restricted by top-k / top-p / min-p filtering.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            if min_p > 0.0:
                logits_top = logits.clone()
                max_prob = logits_top[..., 0].item()
                min_prob = max_prob * min_p
                modify_logits_for_min_p_filtering(logits_top, min_prob)
                if temperature != 1.0:
                    logits_top /= temperature
                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )
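
# Usage sketch (illustrative, not from the original file):
#   next_token = sample(logits, top_k=0, top_p=0.9, temperature=0.8)  # nucleus sampling
#   next_token = sample(logits, top_k=1)                              # greedy decoding
# Either call returns a tensor of shape (batch_size,) holding one token id per sequence.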


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    cg=False,
    enable_timing=False,
    streamer: Optional[TextStreamer] = None,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.
    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors of shape (batch, vocab_size), one per decoding step
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False
    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)
    if enable_timing:
        start.record()
    scores, sequences = [], [input_ids]
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(scores[-1], inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                scores[-1].clone(), sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
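
# Usage sketch (illustrative; the model is a placeholder, not defined in this file):
#   out = decode(input_ids, model, max_length=128, top_k=0, top_p=0.9, temperature=0.8)
#   generated = out.sequences        # (batch, <= max_length), prompt followed by generated tokens
#   per_step_logits = out.scores     # tuple of (batch, vocab_size) tensors, one per decoding step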


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        min_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
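
# Usage sketch (illustrative, not from this file): a model class that inherits from
# GenerationMixin and implements allocate_inference_cache can be called like
#   out = model.generate(input_ids, max_length=100, top_k=0, top_p=0.9, temperature=0.7,
#                        return_dict_in_generate=True, output_scores=True)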


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        assert hasattr(
            model, "allocate_inference_cache"
        ), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache
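
# Note on the cache layout (descriptive, not from the original comments): captured graphs are
# stored in cache.callables keyed by (batch_size, decoding_seqlen), so repeated decode calls with
# the same batch size replay the same graph; the whole cache (memory pool, inference_params,
# callables) is rebuilt only when the device/dtype changes or a larger batch size or max_seqlen
# is requested than was previously allocated.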


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launches and non-captured launches do not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Capture the graph. torch.cuda.graph sets a side stream as the current stream inside the
    # context, which is required for capture.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run
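
# Note (descriptive, not from the original comments): run() copies the new token ids and position
# ids into the same static buffers that were used during capture and then replays the recorded
# graph, so every kernel reads and writes fixed memory addresses; logits.clone() is needed because
# the graph's output buffer is overwritten by the next replay.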