Spaces:
Runtime error
Runtime error
wenmengzhou
commited on
update model according to hysts advice
Browse files- cosyvoice/cli/model.py +4 -6
cosyvoice/cli/model.py
CHANGED
@@ -19,18 +19,17 @@ class CosyVoiceModel:
|
|
19 |
llm: torch.nn.Module,
|
20 |
flow: torch.nn.Module,
|
21 |
hift: torch.nn.Module):
|
22 |
-
|
23 |
-
self.device = 'cpu'
|
24 |
self.llm = llm
|
25 |
self.flow = flow
|
26 |
self.hift = hift
|
27 |
|
28 |
def load(self, llm_model, flow_model, hift_model):
|
29 |
-
self.llm.load_state_dict(torch.load(llm_model, map_location=
|
30 |
self.llm.to(self.device).eval()
|
31 |
-
self.flow.load_state_dict(torch.load(flow_model, map_location=
|
32 |
self.flow.to(self.device).eval()
|
33 |
-
self.hift.load_state_dict(torch.load(hift_model, map_location=
|
34 |
self.hift.to(self.device).eval()
|
35 |
|
36 |
def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
@@ -38,7 +37,6 @@ class CosyVoiceModel:
|
|
38 |
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
39 |
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
40 |
prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
|
41 |
-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
42 |
tts_speech_token = self.llm.inference(text=text.to(self.device),
|
43 |
text_len=text_len.to(self.device),
|
44 |
prompt_text=prompt_text.to(self.device),
|
|
|
19 |
llm: torch.nn.Module,
|
20 |
flow: torch.nn.Module,
|
21 |
hift: torch.nn.Module):
|
22 |
+
self.device = torch.device('cuda')
|
|
|
23 |
self.llm = llm
|
24 |
self.flow = flow
|
25 |
self.hift = hift
|
26 |
|
27 |
def load(self, llm_model, flow_model, hift_model):
|
28 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location='cpu'))
|
29 |
self.llm.to(self.device).eval()
|
30 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location='cpu'))
|
31 |
self.flow.to(self.device).eval()
|
32 |
+
self.hift.load_state_dict(torch.load(hift_model, map_location='cpu'))
|
33 |
self.hift.to(self.device).eval()
|
34 |
|
35 |
def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
|
|
37 |
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
38 |
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
39 |
prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
|
|
|
40 |
tts_speech_token = self.llm.inference(text=text.to(self.device),
|
41 |
text_len=text_len.to(self.device),
|
42 |
prompt_text=prompt_text.to(self.device),
|