Commit 5cfeca6 by XXXXRT666 · Parent(s): 7bdf3c3

Cache CUDA Graph

Files changed:
- AR/models/structs.py               +4  -6
- AR/models/t2s_model_abc.py         +33 -13
- AR/models/t2s_model_flash_attn.py  +62 -38
- README.md                          +1  -1
- inference_webui.py                 +6  -3
AR/models/structs.py CHANGED

@@ -1,3 +1,7 @@
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
 from __future__ import annotations

 from dataclasses import dataclass

@@ -48,7 +52,6 @@ class T2SSession:
         self.y_len = y_len

         # Cache
-        self.kv_cache = decoder.init_cache(bsz)
         self.sampler = Sampler(bsz, decoder.vocab_size)

         # Forward args

@@ -62,11 +65,6 @@ class T2SSession:
         self.input_pos = torch.zeros_like(self.prefill_len)
         self.input_pos.add_(self.prefill_len)

-        # CUDA Graph
-        self.graph: Optional[torch.cuda.CUDAGraph] = None
-        self.xy_pos_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
-        self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
-
         # EOS
         self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
         self.y_results: List[Tensor] = [None] * len(self.x)  # type: ignore
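These removals are the heart of the commit: the CUDA graph and its static buffers leave the per-request session. A captured torch.cuda.CUDAGraph replays against the exact tensor addresses it saw at capture time, so those buffers must outlive any single request; they reappear below as fields of the long-lived CUDAGraphRunner in t2s_model_flash_attn.py. A minimal sketch of that ownership rule (GraphOwner is a hypothetical name, not from the repo):

    from typing import Optional

    import torch

    class GraphOwner:
        # Illustrative sketch: a captured CUDA graph reads and writes the
        # exact tensor addresses it saw at capture time, so the graph and
        # its static buffers must share one lifetime -- a long-lived
        # runner, not a per-request session.
        def __init__(self, embedding_dim: int, dtype: torch.dtype = torch.float16):
            self.graph: Optional[torch.cuda.CUDAGraph] = None
            self.xy_pos_ = torch.rand((1, 1, embedding_dim), device="cuda").to(dtype)
            self.xy_dec_ = torch.rand((1, 1, embedding_dim), device="cuda").to(dtype)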
    	
AR/models/t2s_model_abc.py CHANGED

@@ -1,9 +1,14 @@
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
 from __future__ import annotations

 import os
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from typing import Any, Dict, List, MutableSequence, Optional, Tuple, Type
+import time

 import torch
 import torch._inductor.config

@@ -31,6 +36,7 @@ class Sampler(nn.Module):
         self.register_buffer("samples", torch.zeros((batch_size,), dtype=torch.int32), persistent=False)

         self.__CUDAGraph: Optional[CUDAGraph] = None
+

     def empty_cache(self):
         self.logits.zero_()

@@ -139,6 +145,7 @@ class Sampler(nn.Module):
         return idx_next

     def capture(self, temperature: float, top_k: int, top_p: float):
+        t1 = time.perf_counter()
         s = torch.cuda.Stream()
         s.wait_stream(torch.cuda.current_stream())

@@ -153,7 +160,9 @@
         with torch.cuda.graph(self.__CUDAGraph):
             self.samples = self.__sample_cuda_graph(logits, temperature, top_k, top_p)
         torch.cuda.synchronize()
+        print("Sample", time.perf_counter() - t1)

+    # @torch.jit.script
     def sample(
         self,
         logits: Tensor,
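For reference, the capture() being timed here follows PyTorch's standard graph-capture recipe: warm the operation up on a side stream so allocator setup is not recorded, then record the kernels under torch.cuda.graph. A minimal standalone sketch, with fn standing in for the module's private __sample_cuda_graph:

    import torch

    def capture(fn, static_logits: torch.Tensor):
        # Warm up on a side stream so one-time setup work (allocations,
        # autotuning) is not baked into the recording.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                fn(static_logits)
        torch.cuda.current_stream().wait_stream(s)

        # Record: kernels launched inside this block go into the graph
        # instead of executing eagerly.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = fn(static_logits)
        torch.cuda.synchronize()
        return graph, static_out  # later: graph.replay(); read static_out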
@@ -162,21 +171,32 @@
         top_k: int,
         top_p: float,
         repetition_penalty: float,
-        use_cuda_graph=False,
-        idx=-1,
     ) -> Tensor:
-        ...                       # (old method body truncated in this capture)
-        return samples
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=1, index=previous_tokens)
+        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+        logits.scatter_(dim=1, index=previous_tokens, src=score)
+
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_indices_to_remove = cum_probs > top_p
+        sorted_indices_to_remove[:, 0] = False  # keep at least one option
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+        logits = logits / max(temperature, 1e-5)
+
+        v, _ = torch.topk(logits, top_k)
+        pivot = v[:, -1].unsqueeze(-1)
+        logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        q = torch.empty_like(probs).exponential_(1.0)
+        idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
+
+        return idx_next


 class KVCacheABC(ABC, nn.Module):
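The new sample() body ends with an exponential-race draw instead of torch.multinomial: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability probs[i] / probs.sum(), and it avoids the CPU synchronization multinomial incurs, which keeps the op capturable inside a CUDA graph. A small self-contained check of the equivalence:

    import torch

    probs = torch.tensor([0.1, 0.2, 0.7])

    # argmax(probs / q) with q ~ Exponential(1) is distributed like
    # torch.multinomial(probs, 1): index i wins the race with
    # probability probs[i] / probs.sum().
    counts = torch.zeros(3)
    for _ in range(10_000):
        q = torch.empty_like(probs).exponential_(1.0)
        counts[torch.argmax(probs / q)] += 1
    print(counts / counts.sum())  # approximately tensor([0.1, 0.2, 0.7])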
    	
AR/models/t2s_model_flash_attn.py CHANGED

@@ -1,8 +1,12 @@
-
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
 import os
 import time
 import traceback
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
+import gradio as gr

 import flash_attn  # type: ignore
 import torch

@@ -50,7 +54,7 @@ class Attention(AttentionABC):

         attn: Tensor = flash_attn.flash_attn_with_kvcache(
             q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
-        )
+        )  # type: ignore

         attn = self.dropout.forward(attn)

@@ -215,57 +219,66 @@ class CUDAGraphRunner:

         self.decoder_path: os.PathLike
         self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
+
+        self.graph: Optional[torch.cuda.CUDAGraph] = None
+        self.xy_pos_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
+        self.xy_dec_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
+        self.kv_cache = decoder_model.init_cache(1)
+        self.input_pos = torch.tensor([10]).int().cuda()

     def _handle_request(self, request: T2SRequest) -> List[torch.Tensor]:
         with self.device:
+            for i in self.kv_cache:
+                i.empty()
+
             decoder = self.decoder_model
             session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
-
-            y = session.y
-            bsz = y.size(0)
+            self.input_pos.copy_(session.input_pos)
+
             t1 = 0.0
-
+            y = session.y
+            bsz = y.size(0)
             torch_profiler = TorchProfiler(request.debug)
-
             with torch_profiler.profiler():
                 for idx in tqdm(range(1500)):
                     if idx == 0:
-                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, session.kv_cache)
+                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, self.kv_cache)
                         xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
                     else:
-                        if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
-                            session.xy_pos_.copy_(session.xy_pos)
+                        if request.use_cuda_graph and self.graph is None and torch.cuda.is_available():
+                            self.xy_pos_.copy_(session.xy_pos)
                             args, kwds = decoder.pre_forward(session)
-                            session.graph = decoder.capture(
-                                session.input_pos,
-                                session.xy_pos_,
-                                session.xy_dec_,
-                                kv_caches=session.kv_cache,
+                            self.graph = decoder.capture(
+                                self.input_pos,
+                                self.xy_pos_,
+                                self.xy_dec_,
+                                kv_caches=self.kv_cache,
                                 *args,
                                 **kwds,
                             )

                         with torch_profiler.record("AR"):
-                            if session.graph:
-                                session.xy_pos_.copy_(session.xy_pos)
-                                session.graph.replay()
-                                xy_dec = session.xy_dec_.clone()
+                            if self.graph:
+                                self.xy_pos_.copy_(session.xy_pos)
+                                self.graph.replay()
+                                xy_dec = self.xy_dec_.clone()
                             else:
                                 args, kwds = decoder.pre_forward(session)
                                 xy_dec = decoder.h.forward(
-                                    session.input_pos,
+                                    self.input_pos,
                                     session.xy_pos,
-                                    session.kv_cache,
+                                    self.kv_cache,
                                     *args,
                                     **kwds,
                                 )

                     decoder.post_forward(idx, session)
                     logits = decoder.ar_predict_layer(xy_dec[:, -1])
-                    session.input_pos.add_(1)
+                    self.input_pos.add_(1)

                     if idx == 0:
                         logits[:, -1] = float("-inf")

                     with torch_profiler.record("Sampling"):
                         samples = session.sampler.sample(
                             logits=logits,

@@ -274,27 +287,26 @@ class CUDAGraphRunner:
                             top_p=request.top_p,
                             repetition_penalty=request.repetition_penalty,
                             temperature=request.temperature,
-                            use_cuda_graph=request.use_cuda_graph,
-                            idx=idx,
                         )

                         session.y = torch.cat([session.y, samples], dim=1)

                     with torch_profiler.record("EOS"):
                         argmax_token = torch.argmax(logits, dim=-1)
                         sample_token = samples.squeeze(1)
                         EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)

                         newly_done_mask = EOS_mask & (~session.completed)
-                    with torch_profiler.record("EOS2"):
                         newly_done_indices = newly_done_mask.nonzero()

                         if newly_done_indices.numel() > 0:
                             session.y_results[newly_done_indices[0]] = session.y[
                                 newly_done_indices[0], session.y_len : -1
                             ].squeeze(0)
                             session.completed[newly_done_indices] = True

                         if torch.all(session.completed).item():
                             if session.y.size(1) == 0:
                                 session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)

@@ -304,11 +316,12 @@ class CUDAGraphRunner:
                                     f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
                                 )
                                 tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+                                gr.Info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s", duration=0.75)
                             break
-
                         if (
-                            request.early_stop_num != -1
-                            and (session.y.size(1) - session.y_len) > request.early_stop_num
+                            (request.early_stop_num != -1
+                            and (session.y.size(1) - session.y_len) > request.early_stop_num) or idx == 1499
                         ):
                             for i in range(bsz):
                                 if not session.completed[i].item():

@@ -318,14 +331,25 @@ class CUDAGraphRunner:

                     with torch_profiler.record("NextPos"):
                         y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
-                        session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
+                        session.xy_pos = decoder.ar_audio_position.forward(self.input_pos - session.x_lens, y_emb)

                     if idx == 2:
                         torch_profiler.start()
                         t1 = time.perf_counter()

-                    ...                   # (removed lines truncated in this capture)
+                    if idx == 51:
+                        torch_profiler.end()
+
+                    if idx % 100 == 0:
+                        match session.device.type:
+                            case "cuda":
+                                torch.cuda.empty_cache()
+                            case "mps":
+                                torch.mps.empty_cache()
+                            case "xpu":
+                                torch.xpu.empty_cache()
+                            case "mtia":
+                                torch.mtia.empty_cache()

                 match session.device.type:
                     case "cuda":

@@ -336,7 +360,7 @@ class CUDAGraphRunner:
                         torch.xpu.empty_cache()
                     case "mtia":
                         torch.mtia.empty_cache()

                 torch_profiler.end()
                 return session.y_results[: request.valid_length]
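Taken together, the runner now follows the capture-once, replay-per-step pattern: copy live values into fixed-address buffers, replay the recorded kernels, and clone the static output before the next replay overwrites it. A self-contained sketch of the same flow on a toy model (illustrative names, not the repo's API):

    import torch

    model = torch.nn.Linear(8, 8).cuda().eval()
    static_in = torch.zeros(1, 8, device="cuda")

    # Warm up before capture so one-time setup is not recorded.
    with torch.no_grad():
        for _ in range(3):
            model(static_in)
    torch.cuda.synchronize()

    # Capture once: records the kernels and the fixed buffer addresses they use.
    graph = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(graph):
        static_out = model(static_in)

    # Replay per step: refresh the static input in place, replay, then
    # snapshot the static output before the next replay mutates it.
    for step_in in torch.randn(5, 1, 8, device="cuda"):
        static_in.copy_(step_in)
        graph.replay()
        result = static_out.clone()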
    	
README.md CHANGED

@@ -4,7 +4,7 @@ emoji: 🤖
 colorFrom: indigo
 colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 5.20.0
 app_file: inference_webui.py
 pinned: false
 license: mit
    	
inference_webui.py CHANGED

@@ -57,6 +57,10 @@ import LangSegment
 import spaces
 import torch

+import threading
+
+lock = threading.Lock()
+
 version = "v2"  # os.environ.get("version","v2")
 cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base")
 bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")

@@ -540,7 +544,7 @@ def get_tts_wav(
         if i_text in cache and if_freeze == True:
             pred_semantic = cache[i_text]
         else:
-            with torch.no_grad():
+            with torch.no_grad(), lock:
                 t2s_request = T2SRequest(
                     [all_phoneme_ids.squeeze(0)],
                     all_phoneme_len,

@@ -552,7 +556,7 @@ def get_tts_wav(
                     temperature=temperature,
                     early_stop_num=1500,
                     use_cuda_graph=True,
-                    debug=True,
+                    # debug=True,
                 )
                 t2s_result = t2s_model.generate(t2s_request)
                 pred_semantic = t2s_result.result

@@ -836,5 +840,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         inbrowser=True,
         show_api=False,
-        server_port=1111,
     )
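With the KV cache, position counter, static buffers, and captured graph all hanging off one shared runner, two concurrent Gradio requests must not interleave, which is what the new module-level lock enforces. The pattern in isolation (run_t2s is a hypothetical wrapper; t2s_model and T2SRequest are the source's objects):

    import threading

    import torch

    lock = threading.Lock()

    def run_t2s(t2s_request):
        # Serialize requests: a second thread replaying the shared graph
        # would clobber the runner's KV cache and static in/out buffers.
        with torch.no_grad(), lock:
            t2s_result = t2s_model.generate(t2s_request)  # t2s_model as in the source
        return t2s_result.result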