Spaces:
Runtime error
Runtime error
XXXXRT666
commited on
Commit
Β·
3011ece
1
Parent(s):
4ae2215
- AR/models/t2s_model_abc.py +3 -1
- AR/models/t2s_model_flash_attn.py +1 -1
- inference_webui.py +1 -0
AR/models/t2s_model_abc.py
CHANGED
|
@@ -449,7 +449,9 @@ class CUDAGraphCacheABC(ABC):
|
|
| 449 |
def assign_graph(self, session: Any):
|
| 450 |
if self.graph is None:
|
| 451 |
args, kwds = self.decoder.pre_forward(session)
|
| 452 |
-
graph = self.decoder.capture(
|
|
|
|
|
|
|
| 453 |
self.graph = graph
|
| 454 |
|
| 455 |
if self.assigned is False:
|
|
|
|
| 449 |
def assign_graph(self, session: Any):
|
| 450 |
if self.graph is None:
|
| 451 |
args, kwds = self.decoder.pre_forward(session)
|
| 452 |
+
graph = self.decoder.capture(
|
| 453 |
+
self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds
|
| 454 |
+
)
|
| 455 |
self.graph = graph
|
| 456 |
|
| 457 |
if self.assigned is False:
|
AR/models/t2s_model_flash_attn.py
CHANGED
|
@@ -239,7 +239,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
|
| 239 |
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
| 240 |
|
| 241 |
args, kwds = self.decoder.pre_forward(session)
|
| 242 |
-
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, *args, **kwds)
|
| 243 |
session.graph = graph
|
| 244 |
|
| 245 |
|
|
|
|
| 239 |
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
| 240 |
|
| 241 |
args, kwds = self.decoder.pre_forward(session)
|
| 242 |
+
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds)
|
| 243 |
session.graph = graph
|
| 244 |
|
| 245 |
|
inference_webui.py
CHANGED
|
@@ -38,6 +38,7 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
|
| 38 |
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
|
| 39 |
logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
|
| 40 |
logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
|
|
|
|
| 41 |
|
| 42 |
os.makedirs("pretrained_models", exist_ok=True)
|
| 43 |
|
|
|
|
| 38 |
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
|
| 39 |
logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
|
| 40 |
logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
|
| 41 |
+
logging.getLogger("filelock").setLevel(logging.INFO)
|
| 42 |
|
| 43 |
os.makedirs("pretrained_models", exist_ok=True)
|
| 44 |
|