FrankZxShen committed
Commit 0f66f70
Parent: 29755d6

Update utils.py

Files changed (1):
  utils.py  +0 -87
utils.py CHANGED
@@ -10,8 +10,6 @@ from scipy.io.wavfile import read
 import torch
 import regex as re
 
-import loralib as lora
-
 MATPLOTLIB_FLAG = False
 
 logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@@ -146,91 +144,6 @@ def tag_cke(text,prev_sentence=None):
   return prev_lang,tagged_text
 
 
-def load_lora_checkpoint(checkpoint_path, model, optimizer=None, generator_path="./pretrained_models/G_latest.pth"):
-  assert os.path.isfile(checkpoint_path)
-  checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-  iteration = checkpoint_dict['iteration']
-  learning_rate = checkpoint_dict['learning_rate']
-  if optimizer is not None:
-    optimizer.load_state_dict(checkpoint_dict['optimizer'])
-  # saved_state_dict = checkpoint_dict['model']
-  generator_state_dict = torch.load(generator_path)['model']
-  lora_state_dict = checkpoint_dict['model']
-  new_state_dict = {}
-  for k, v in lora_state_dict.items():
-    try:
-      if k == 'emb_g.weight':
-        if drop_speaker_emb:
-          new_state_dict[k] = v
-          continue
-        v[:lora_state_dict[k].shape[0], :] = lora_state_dict[k]
-        new_state_dict[k] = v
-      else:
-        new_state_dict[k] = lora_state_dict[k]
-    except:
-      logger.info("%s is not in the checkpoint" % k)
-      new_state_dict[k] = v
-  if hasattr(model, 'module'):
-    model.module.load_state_dict(generator_state_dict, strict=False)
-    model.module.load_state_dict(new_state_dict, strict=False)
-    # lora.mark_only_lora_as_trainable(model.module)
-  else:
-    model.load_state_dict(generator_state_dict, strict=False)
-    model.load_state_dict(new_state_dict, strict=False)
-    # lora.mark_only_lora_as_trainable(model)
-  logger.info("Loaded checkpoint '{}' (iteration {})".format(
-    checkpoint_path, iteration))
-
-  return model, optimizer, learning_rate, iteration
-
-def save_lora_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-  logger.info("Saving model and optimizer state at iteration {} to {}".format(
-    iteration, checkpoint_path))
-  if hasattr(model, 'module'):
-    state_dict = lora.lora_state_dict(model.module)
-  else:
-    state_dict = lora.lora_state_dict(model)
-  torch.save({'model': state_dict,
-              'iteration': iteration,
-              'optimizer': optimizer.state_dict() if optimizer is not None else None,
-              'learning_rate': learning_rate}, checkpoint_path)
-
-def load_lora_checkpoint_fix(checkpoint_path, model, optimizer=None, drop_speaker_emb=False):
-  assert os.path.isfile(checkpoint_path)
-  checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
-  iteration = checkpoint_dict['iteration']
-  learning_rate = checkpoint_dict['learning_rate']
-  if optimizer is not None:
-    optimizer.load_state_dict(checkpoint_dict['optimizer'])
-  saved_state_dict = checkpoint_dict['model']
-  if hasattr(model, 'module'):
-    state_dict = model.module.state_dict()
-  else:
-    state_dict = model.state_dict()
-  new_state_dict = {}
-  for k, v in state_dict.items():
-    try:
-      if k == 'emb_g.weight':
-        if drop_speaker_emb:
-          new_state_dict[k] = v
-          continue
-        v[:saved_state_dict[k].shape[0], :] = saved_state_dict[k]
-        new_state_dict[k] = v
-      else:
-        new_state_dict[k] = saved_state_dict[k]
-    except:
-      logger.info("%s is not in the checkpoint" % k)
-      new_state_dict[k] = v
-  if hasattr(model, 'module'):
-    model.module.load_state_dict(new_state_dict, strict=False)
-    lora.mark_only_lora_as_trainable(model.module)
-  else:
-    model.load_state_dict(new_state_dict, strict=False)
-    lora.mark_only_lora_as_trainable(model)
-  logger.info("Loaded checkpoint '{}' (iteration {})".format(
-    checkpoint_path, iteration))
-  return model, optimizer, learning_rate, iteration
-
 def load_checkpoint(checkpoint_path, model, optimizer=None, drop_speaker_emb=False):
   assert os.path.isfile(checkpoint_path)
   checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
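For context, the removed helpers were thin wrappers around loralib: save_lora_checkpoint persisted only the LoRA weights via lora.lora_state_dict, load_lora_checkpoint rebuilt a full generator by loading the pretrained ./pretrained_models/G_latest.pth first and overlaying the LoRA-only state dict on top (both non-strict), and load_lora_checkpoint_fix froze everything except the LoRA matrices with lora.mark_only_lora_as_trainable. A minimal sketch of that round trip, assuming a generator net_g that already has LoRA layers injected, an optimizer optim_g, and a step counter and learning rate; these names and the lora_latest.pth path are illustrative, not taken from this repo:

import torch
import loralib as lora

# Train only the LoRA matrices; all other parameters stay frozen
# (the removed load_lora_checkpoint_fix called this after loading weights).
lora.mark_only_lora_as_trainable(net_g)

# Persist only the LoRA weights plus training state
# (equivalent to the removed save_lora_checkpoint).
torch.save({'model': lora.lora_state_dict(net_g),
            'iteration': step,
            'optimizer': optim_g.state_dict(),
            'learning_rate': lr}, "lora_latest.pth")

# Resume (equivalent to the removed load_lora_checkpoint): load the full
# pretrained generator first, then overlay the LoRA-only state dict,
# both with strict=False so the two partial dicts can be applied in turn.
base = torch.load("./pretrained_models/G_latest.pth", map_location='cpu')['model']
delta = torch.load("lora_latest.pth", map_location='cpu')['model']
net_g.load_state_dict(base, strict=False)
net_g.load_state_dict(delta, strict=False)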