Spaces:
Running
on
L4
Running
on
L4
lengyue233
commited on
Optimize graph
Browse files- app.py +5 -10
- tools/llama/generate.py +37 -26
app.py
CHANGED
@@ -41,6 +41,9 @@ Related code are released under BSD-3-Clause License, and weights are released u
|
|
41 |
|
42 |
We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
|
43 |
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
|
|
|
|
|
|
|
44 |
"""
|
45 |
|
46 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
@@ -76,7 +79,6 @@ def inference(
|
|
76 |
reference_text,
|
77 |
max_new_tokens,
|
78 |
chunk_length,
|
79 |
-
top_k,
|
80 |
top_p,
|
81 |
repetition_penalty,
|
82 |
temperature,
|
@@ -112,7 +114,6 @@ def inference(
|
|
112 |
device=vqgan_model.device,
|
113 |
max_new_tokens=max_new_tokens,
|
114 |
text=text,
|
115 |
-
top_k=int(top_k) if top_k > 0 else None,
|
116 |
top_p=top_p,
|
117 |
repetition_penalty=repetition_penalty,
|
118 |
temperature=temperature,
|
@@ -194,10 +195,6 @@ def build_app():
|
|
194 |
step=8,
|
195 |
)
|
196 |
|
197 |
-
top_k = gr.Slider(
|
198 |
-
label="Top-K", minimum=0, maximum=5, value=0, step=1
|
199 |
-
)
|
200 |
-
|
201 |
top_p = gr.Slider(
|
202 |
label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
|
203 |
)
|
@@ -264,7 +261,6 @@ def build_app():
|
|
264 |
reference_text,
|
265 |
max_new_tokens,
|
266 |
chunk_length,
|
267 |
-
top_k,
|
268 |
top_p,
|
269 |
repetition_penalty,
|
270 |
temperature,
|
@@ -310,8 +306,8 @@ if __name__ == "__main__":
|
|
310 |
args.compile = True
|
311 |
args.max_gradio_length = 1024
|
312 |
args.tokenizer = "./checkpoints/fish-speech-1"
|
313 |
-
args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-
|
314 |
-
args.llama_config_name = "
|
315 |
args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
|
316 |
args.vqgan_config_name = "vqgan_pretrain"
|
317 |
|
@@ -343,7 +339,6 @@ if __name__ == "__main__":
|
|
343 |
reference_text="",
|
344 |
max_new_tokens=0,
|
345 |
chunk_length=0,
|
346 |
-
top_k=0, # 0 means no limit
|
347 |
top_p=0.7,
|
348 |
repetition_penalty=1.5,
|
349 |
temperature=0.7,
|
|
|
41 |
|
42 |
We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
|
43 |
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
|
44 |
+
|
45 |
+
The model running in this WebUI is Fish Speech V1 Medium SFT 4K.
|
46 |
+
在此 WebUI 中运行的模型是 Fish Speech V1 Medium SFT 4K.
|
47 |
"""
|
48 |
|
49 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
|
|
79 |
reference_text,
|
80 |
max_new_tokens,
|
81 |
chunk_length,
|
|
|
82 |
top_p,
|
83 |
repetition_penalty,
|
84 |
temperature,
|
|
|
114 |
device=vqgan_model.device,
|
115 |
max_new_tokens=max_new_tokens,
|
116 |
text=text,
|
|
|
117 |
top_p=top_p,
|
118 |
repetition_penalty=repetition_penalty,
|
119 |
temperature=temperature,
|
|
|
195 |
step=8,
|
196 |
)
|
197 |
|
|
|
|
|
|
|
|
|
198 |
top_p = gr.Slider(
|
199 |
label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
|
200 |
)
|
|
|
261 |
reference_text,
|
262 |
max_new_tokens,
|
263 |
chunk_length,
|
|
|
264 |
top_p,
|
265 |
repetition_penalty,
|
266 |
temperature,
|
|
|
306 |
args.compile = True
|
307 |
args.max_gradio_length = 1024
|
308 |
args.tokenizer = "./checkpoints/fish-speech-1"
|
309 |
+
args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-medium-v1-4k.pth"
|
310 |
+
args.llama_config_name = "dual_ar_2_codebook_medium"
|
311 |
args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
|
312 |
args.vqgan_config_name = "vqgan_pretrain"
|
313 |
|
|
|
339 |
reference_text="",
|
340 |
max_new_tokens=0,
|
341 |
chunk_length=0,
|
|
|
342 |
top_p=0.7,
|
343 |
repetition_penalty=1.5,
|
344 |
temperature=0.7,
|
tools/llama/generate.py
CHANGED
@@ -42,11 +42,11 @@ def multinomial_sample_one_no_sync(
|
|
42 |
def logits_to_probs(
|
43 |
logits,
|
44 |
previous_tokens: Optional[torch.Tensor] = None,
|
45 |
-
temperature:
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
if previous_tokens is not None:
|
51 |
previous_tokens = previous_tokens.long()
|
52 |
score = torch.gather(logits, dim=0, index=previous_tokens)
|
@@ -55,11 +55,9 @@ def logits_to_probs(
|
|
55 |
)
|
56 |
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
57 |
|
58 |
-
#
|
59 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
60 |
-
cum_probs = torch.cumsum(
|
61 |
-
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
62 |
-
)
|
63 |
sorted_indices_to_remove = cum_probs > top_p
|
64 |
sorted_indices_to_remove[0] = False # keep at least one option
|
65 |
indices_to_remove = sorted_indices_to_remove.scatter(
|
@@ -69,11 +67,6 @@ def logits_to_probs(
|
|
69 |
|
70 |
logits = logits / max(temperature, 1e-5)
|
71 |
|
72 |
-
# if top_k is not None:
|
73 |
-
# v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
74 |
-
# pivot = v.select(-1, -1).unsqueeze(-1)
|
75 |
-
# logits = torch.where(logits < pivot, -float("Inf"), logits)
|
76 |
-
|
77 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
78 |
return probs
|
79 |
|
@@ -449,7 +442,6 @@ def generate_long(
|
|
449 |
text: str,
|
450 |
num_samples: int = 1,
|
451 |
max_new_tokens: int = 0,
|
452 |
-
top_k: int = None,
|
453 |
top_p: int = 0.7,
|
454 |
repetition_penalty: float = 1.5,
|
455 |
temperature: float = 0.7,
|
@@ -462,6 +454,10 @@ def generate_long(
|
|
462 |
prompt_tokens: Optional[torch.Tensor] = None,
|
463 |
is_streaming: bool = False,
|
464 |
):
|
|
|
|
|
|
|
|
|
465 |
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
466 |
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
467 |
|
@@ -493,8 +489,18 @@ def generate_long(
|
|
493 |
)
|
494 |
logger.info(f"Encoded text: {text}")
|
495 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
for sample_idx in range(num_samples):
|
497 |
-
torch.cuda.
|
|
|
|
|
498 |
global_encoded = []
|
499 |
all_codes = []
|
500 |
seg_idx = 0
|
@@ -540,7 +546,6 @@ def generate_long(
|
|
540 |
im_end_id=im_end_id,
|
541 |
decode_one_token=decode_one_token,
|
542 |
temperature=temperature,
|
543 |
-
top_k=top_k,
|
544 |
top_p=top_p,
|
545 |
repetition_penalty=repetition_penalty,
|
546 |
)
|
@@ -548,7 +553,9 @@ def generate_long(
|
|
548 |
if sample_idx == 0 and seg_idx == 0 and compile:
|
549 |
logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
550 |
|
551 |
-
torch.cuda.
|
|
|
|
|
552 |
t = time.perf_counter() - t0
|
553 |
|
554 |
tokens_generated = y.size(1) - prompt_length
|
@@ -559,9 +566,11 @@ def generate_long(
|
|
559 |
logger.info(
|
560 |
f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
|
561 |
)
|
562 |
-
|
563 |
-
|
564 |
-
|
|
|
|
|
565 |
|
566 |
# Put the generated tokens
|
567 |
# since there is <im_end> and <eos> tokens, we remove last 2 tokens
|
@@ -654,7 +663,6 @@ def launch_thread_safe_queue(
|
|
654 |
)
|
655 |
@click.option("--num-samples", type=int, default=1)
|
656 |
@click.option("--max-new-tokens", type=int, default=0)
|
657 |
-
@click.option("--top-k", type=int, default=None)
|
658 |
@click.option("--top-p", type=float, default=0.7)
|
659 |
@click.option("--repetition-penalty", type=float, default=1.5)
|
660 |
@click.option("--temperature", type=float, default=0.7)
|
@@ -678,7 +686,6 @@ def main(
|
|
678 |
prompt_tokens: Optional[Path],
|
679 |
num_samples: int,
|
680 |
max_new_tokens: int,
|
681 |
-
top_k: int,
|
682 |
top_p: int,
|
683 |
repetition_penalty: float,
|
684 |
temperature: float,
|
@@ -702,7 +709,10 @@ def main(
|
|
702 |
model, decode_one_token = load_model(
|
703 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
704 |
)
|
705 |
-
|
|
|
|
|
|
|
706 |
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
707 |
|
708 |
prompt_tokens = (
|
@@ -713,7 +723,9 @@ def main(
|
|
713 |
|
714 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
715 |
torch.manual_seed(seed)
|
716 |
-
|
|
|
|
|
717 |
|
718 |
generator = generate_long(
|
719 |
model=model,
|
@@ -722,7 +734,6 @@ def main(
|
|
722 |
text=text,
|
723 |
num_samples=num_samples,
|
724 |
max_new_tokens=max_new_tokens,
|
725 |
-
top_k=top_k,
|
726 |
top_p=top_p,
|
727 |
repetition_penalty=repetition_penalty,
|
728 |
temperature=temperature,
|
|
|
42 |
def logits_to_probs(
|
43 |
logits,
|
44 |
previous_tokens: Optional[torch.Tensor] = None,
|
45 |
+
temperature: torch.Tensor = 1.0,
|
46 |
+
top_p: torch.Tensor = 1.0,
|
47 |
+
repetition_penalty: torch.Tensor = 1.0,
|
48 |
+
) -> torch.Tensor:
|
49 |
+
# Apply repetition penalty
|
50 |
if previous_tokens is not None:
|
51 |
previous_tokens = previous_tokens.long()
|
52 |
score = torch.gather(logits, dim=0, index=previous_tokens)
|
|
|
55 |
)
|
56 |
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
57 |
|
58 |
+
# Apply top-p sampling
|
59 |
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
60 |
+
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
|
|
|
|
61 |
sorted_indices_to_remove = cum_probs > top_p
|
62 |
sorted_indices_to_remove[0] = False # keep at least one option
|
63 |
indices_to_remove = sorted_indices_to_remove.scatter(
|
|
|
67 |
|
68 |
logits = logits / max(temperature, 1e-5)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
70 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
71 |
return probs
|
72 |
|
|
|
442 |
text: str,
|
443 |
num_samples: int = 1,
|
444 |
max_new_tokens: int = 0,
|
|
|
445 |
top_p: int = 0.7,
|
446 |
repetition_penalty: float = 1.5,
|
447 |
temperature: float = 0.7,
|
|
|
454 |
prompt_tokens: Optional[torch.Tensor] = None,
|
455 |
is_streaming: bool = False,
|
456 |
):
|
457 |
+
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
458 |
+
assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
459 |
+
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
460 |
+
|
461 |
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
462 |
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
463 |
|
|
|
489 |
)
|
490 |
logger.info(f"Encoded text: {text}")
|
491 |
|
492 |
+
# Move temperature, top_p, repetition_penalty to device
|
493 |
+
# This is important so that changing params doesn't trigger recompile
|
494 |
+
temperature = torch.tensor(temperature, device=device, dtype=torch.float)
|
495 |
+
top_p = torch.tensor(top_p, device=device, dtype=torch.float)
|
496 |
+
repetition_penalty = torch.tensor(
|
497 |
+
repetition_penalty, device=device, dtype=torch.float
|
498 |
+
)
|
499 |
+
|
500 |
for sample_idx in range(num_samples):
|
501 |
+
if torch.cuda.is_available():
|
502 |
+
torch.cuda.synchronize()
|
503 |
+
|
504 |
global_encoded = []
|
505 |
all_codes = []
|
506 |
seg_idx = 0
|
|
|
546 |
im_end_id=im_end_id,
|
547 |
decode_one_token=decode_one_token,
|
548 |
temperature=temperature,
|
|
|
549 |
top_p=top_p,
|
550 |
repetition_penalty=repetition_penalty,
|
551 |
)
|
|
|
553 |
if sample_idx == 0 and seg_idx == 0 and compile:
|
554 |
logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
555 |
|
556 |
+
if torch.cuda.is_available():
|
557 |
+
torch.cuda.synchronize()
|
558 |
+
|
559 |
t = time.perf_counter() - t0
|
560 |
|
561 |
tokens_generated = y.size(1) - prompt_length
|
|
|
566 |
logger.info(
|
567 |
f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
|
568 |
)
|
569 |
+
|
570 |
+
if torch.cuda.is_available():
|
571 |
+
logger.info(
|
572 |
+
f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|
573 |
+
)
|
574 |
|
575 |
# Put the generated tokens
|
576 |
# since there is <im_end> and <eos> tokens, we remove last 2 tokens
|
|
|
663 |
)
|
664 |
@click.option("--num-samples", type=int, default=1)
|
665 |
@click.option("--max-new-tokens", type=int, default=0)
|
|
|
666 |
@click.option("--top-p", type=float, default=0.7)
|
667 |
@click.option("--repetition-penalty", type=float, default=1.5)
|
668 |
@click.option("--temperature", type=float, default=0.7)
|
|
|
686 |
prompt_tokens: Optional[Path],
|
687 |
num_samples: int,
|
688 |
max_new_tokens: int,
|
|
|
689 |
top_p: int,
|
690 |
repetition_penalty: float,
|
691 |
temperature: float,
|
|
|
709 |
model, decode_one_token = load_model(
|
710 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
711 |
)
|
712 |
+
|
713 |
+
if torch.cuda.is_available():
|
714 |
+
torch.cuda.synchronize()
|
715 |
+
|
716 |
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
717 |
|
718 |
prompt_tokens = (
|
|
|
723 |
|
724 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
725 |
torch.manual_seed(seed)
|
726 |
+
|
727 |
+
if torch.cuda.is_available():
|
728 |
+
torch.cuda.manual_seed(seed)
|
729 |
|
730 |
generator = generate_long(
|
731 |
model=model,
|
|
|
734 |
text=text,
|
735 |
num_samples=num_samples,
|
736 |
max_new_tokens=max_new_tokens,
|
|
|
737 |
top_p=top_p,
|
738 |
repetition_penalty=repetition_penalty,
|
739 |
temperature=temperature,
|