import os
import queue
import threading
import time
from pathlib import Path
from typing import Optional, Tuple, Union

import click
import hydra
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
from fish_speech.text.clean import clean_text

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer

def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
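# Note: dividing the (possibly unnormalized) probabilities by i.i.d. Exponential(1)
# noise and taking the argmax draws one index proportionally to the probabilities
# (the exponential-race / Gumbel-max family of tricks), while avoiding the
# host-device synchronization that torch.multinomial would incur.
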
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
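# Sampling pipeline as implemented above: the repetition penalty is applied to the
# raw logits first, then nucleus (top-p) filtering, temperature scaling, and finally
# top-k filtering, before the softmax converts the logits to probabilities.
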
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]
    x = x.hidden_states

    # Cleanup the cache
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        x = model.fast_embeddings(a)
        codebooks.append(a)

    return torch.stack(codebooks, dim=0)
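# In the dual-AR path, the slow transformer predicts the text token and a hidden
# state for the current position; the fast layers then decode the residual codebooks
# autoregressively from that hidden state, re-embedding each sampled code before
# predicting the next one. The fast KV cache is zeroed on every call so that it only
# ever covers the codebooks of the current frame.
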
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)

def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        if (
            cur_token[0, 0, -1] == eos_token_id
            or cur_token[0, 0, -1] == im_end_id
            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
        ):
            break

    return previous_tokens[:, : i + 1]
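# Decoding stops early when the text row emits the eos or <|im_end|> token, or when
# any codebook row emits CODEBOOK_EOS_TOKEN_ID; the repetition penalty only sees a
# sliding window of at most the previous 16 tokens.
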
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )
    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        eos_token_id=eos_token_id,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
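# generate() returns the full sequence (prompt plus newly decoded frames) with shape
# (1 + num_codebooks, seq_len): row 0 carries the text-side tokens, rows 1.. carry
# the semantic codebooks.
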
def encode_tokens(
    tokenizer,
    string,
    bos=True,
    device="cuda",
    prompt_tokens=None,
    speaker=None,
    num_codebooks=4,
):
    string = clean_text(string)

    if speaker is None:
        speaker = "assistant"

    string = (
        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
    )
    if bos:
        string = f"<|begin_of_sequence|>{string}"

    new_tokens = tokenizer.encode(
        string,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), "3-dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]

    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 2

    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, keeping only the first {num_codebooks} codebooks"
        )
        data = data[:num_codebooks]

    # Add eos token for each codebook
    data = torch.cat(
        (
            data,
            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
            * CODEBOOK_EOS_TOKEN_ID,
        ),
        dim=1,
    )

    # Since 1.0, we use <|semantic|>
    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    main_token_ids = (
        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
    )
    main_token_ids[0, -1] = end_token_id

    data = torch.cat((main_token_ids, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt
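# encode_tokens builds a (1 + num_codebooks, seq_len) prompt: row 0 holds the
# chat-formatted text tokens, with <|semantic|> filling the positions of the
# reference audio and a trailing <|im_end|>; rows 1.. hold the codebook ids, padded
# with CODEBOOK_PAD_TOKEN_ID under pure text. Reference codes are shifted by +2
# (presumably to reserve the pad/eos code ids) and terminated with
# CODEBOOK_EOS_TOKEN_ID.
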
def load_model(
    config_name, checkpoint_path, device, precision, max_length, compile=False
):
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        cfg = compose(
            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
        )

    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if any(k.startswith("model.") for k in checkpoint):
        checkpoint = {
            k.replace("model.", ""): v
            for k, v in checkpoint.items()
            if k.startswith("model.")
        }

    model.load_state_dict(checkpoint, assign=True)

    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    return model.eval(), decode_one_token

def split_text(text, min_length):
    text = clean_text(text)
    segments = []
    curr = ""
    for char in text:
        curr += char
        if char not in [".", ",", "!", "?"]:
            continue

        if len(curr) >= min_length:
            segments.append(curr)
            curr = ""

    if curr:
        segments.append(curr)

    return segments
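# split_text greedily accumulates characters and only cuts at ".", ",", "!" or "?"
# once the running chunk has reached min_length, so every segment ends on
# punctuation except possibly the trailing remainder.
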
def generate_long(
    *,
    model,
    tokenizer: callable,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_k: Optional[int] = None,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 30,
    speaker: Optional[str] = None,
    prompt_text: Optional[str] = None,
    prompt_tokens: Optional[torch.Tensor] = None,
    is_streaming: bool = False,
):
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    use_prompt = prompt_text is not None and prompt_tokens is not None
    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    if use_prompt:
        encoded.append(
            encode_tokens(
                tokenizer,
                prompt_text,
                prompt_tokens=prompt_tokens,
                bos=True,
                device=device,
                speaker=speaker,
                num_codebooks=model.config.num_codebooks,
            )
        )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                bos=idx == 0 and not use_prompt,
                device=device,
                speaker=speaker,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    for sample_idx in range(num_samples):
        torch.cuda.synchronize()
        global_encoded = []
        all_codes = []
        seg_idx = 0

        if use_prompt:
            seg_idx = 1
            global_encoded.append(encoded[0])

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024:
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            torch.cuda.synchronize()
            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )
            logger.info(
                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
            )

            # Collect the generated tokens.
            # Since there are <im_end> and <eos> tokens at the end, we remove the last 2 tokens.
            codes = y[1:, prompt_length:-2].clone()

            codes = codes - 2
            assert (codes >= 0).all(), f"Negative code found: {codes}"

            decoded = y[:, prompt_length:-1].clone()
            if decoded[0, -1] != im_end_id:  # <im_end>
                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
                decoded = torch.cat(
                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
                )

            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)

            if is_streaming:
                assert (codes >= 0).all(), f"Negative code found: {codes}"
                yield codes
            else:
                all_codes.append(codes)

            seg_idx += 1

        if is_streaming:
            # This indicates the end of the current sample
            yield None
        else:
            all_codes = torch.cat(all_codes, dim=1)
            assert (all_codes >= 0).all(), f"Negative code found: {all_codes}"
            yield all_codes
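# generate_long is a generator: with is_streaming=True it yields one code tensor per
# decoded segment and a None sentinel at the end of each sample; otherwise it yields
# a single concatenated (num_codebooks, total_len) tensor per sample.
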
def launch_thread_safe_queue(
    config_name,
    checkpoint_path,
    device,
    precision,
    max_length,
    compile=False,
):
    input_queue = queue.Queue()

    def worker():
        model, decode_one_token = load_model(
            config_name, checkpoint_path, device, precision, max_length, compile=compile
        )

        while True:
            item = input_queue.get()
            if item is None:
                break

            kwargs = item["request"]
            event = item["event"]

            try:
                item["success"] = True
                item["response"] = list(
                    generate_long(
                        model=model, decode_one_token=decode_one_token, **kwargs
                    )
                )
            except Exception as e:
                item["success"] = False
                item["response"] = e

            event.set()

    threading.Thread(target=worker, daemon=True).start()

    return input_queue
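# NOTE: in the upstream CLI, `main` is a `click` command; the decorators appear to
# have been lost when this file was extracted, so the bare `main()` call at the
# bottom would fail. The sketch below restores a plausible set of options: names
# mirror the function parameters, and defaults mirror the generate_long() defaults
# above where available. The remaining defaults (e.g. --seed) and the required/
# optional choices are assumptions, not the upstream values.
@click.command()
@click.option("--text", type=str, required=True)
@click.option("--prompt-text", type=str, default=None)
@click.option("--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-k", type=int, default=None)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option("--checkpoint-path", type=click.Path(path_type=Path, exists=True), required=True)
@click.option("--config-name", type=str, required=True)
@click.option("--tokenizer", type=str, required=True)
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--max-length", type=int, default=2048)
@click.option("--chunk-length", type=int, default=30)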
def main(
    text: str,
    prompt_text: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
    speaker: Optional[str],
    half: bool,
    iterative_prompt: bool,
    max_length: int,
    chunk_length: int,
) -> None:
    device = "cuda"

    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        config_name, checkpoint_path, device, precision, max_length, compile=compile
    )
    torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        tokenizer=tokenizer,
        compile=compile,
        speaker=speaker,
        iterative_prompt=iterative_prompt,
        max_length=max_length,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    for idx, codes in enumerate(generator):
        np.save(f"codes_{idx}.npy", codes.cpu().numpy())
        logger.info(f"Saved codes to codes_{idx}.npy")


if __name__ == "__main__":
    main()
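# Example invocation (paths, config name, and tokenizer repo are illustrative
# placeholders, and the flag names assume the reconstructed click options above):
#   python generate.py --text "Hello world." \
#       --checkpoint-path <path/to/text2semantic_checkpoint.pth> \
#       --config-name <model_config> --tokenizer <hf_tokenizer_repo>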