ing0 committed
Commit ff46889 · 1 Parent(s): 8d10dcc
Files changed (3):
  1. app.py +5 -1
  2. diffrhythm/model/cfm.py +10 -16
  3. diffrhythm/model/dit.py +7 -3
app.py CHANGED
@@ -232,8 +232,12 @@ with gr.Blocks(css=css) as demo:
 3. **Supported Languages**
    - **Chinese and English**
    - More languages coming soon
+
+4. **Editing Function in Advanced Settings**
+   - Using full-length audio as the reference is recommended for best results.
+   - Use -1 to represent the start/end of the audio (e.g. [[-1,25], [50,-1]] means "from start to 25s" and "from 50s to end").
 
-4. **Others**
+5. **Others**
    - If the audio result loads slowly, you can select Output Format as mp3 in Advanced Settings.
 
 """)
diffrhythm/model/cfm.py CHANGED
@@ -208,27 +208,21 @@ class CFM(nn.Module):
         negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
         start_time = start_time.repeat(batch_infer_num)
         fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)
-
-        start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
-        _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
-
-        text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
-        text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
-        step_cond = torch.cat([step_cond, step_cond], 0)
-        style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
-        start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
 
         def fn(t, x):
-            x = torch.cat([x, x], 0)
+            # predict flow
             pred = self.transformer(
-                x=x, text_embed=text_embed, text_residuals=text_residuals, cond=step_cond, time=t,
-                drop_audio_cond=True, drop_prompt=False, style_prompt=style_prompt, start_time=start_time_embed
+                x=x, cond=step_cond, text=text, time=t, drop_audio_cond=False, drop_text=False, drop_prompt=False,
+                style_prompt=style_prompt, start_time=start_time
             )
+            if cfg_strength < 1e-5:
+                return pred
 
-            positive_pred, negative_pred = pred.chunk(2, 0)
-            cfg_pred = positive_pred + (positive_pred - negative_pred) * cfg_strength
-
-            return cfg_pred
+            null_pred = self.transformer(
+                x=x, cond=step_cond, text=text, time=t, drop_audio_cond=True, drop_text=True, drop_prompt=False,
+                style_prompt=negative_style_prompt, start_time=start_time
+            )
+            return pred + (pred - null_pred) * cfg_strength
 
         # noise input
         # to make sure batch inference results are the same across different batch sizes, and likewise for single inference
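
Note on this hunk: the old path ran the conditional and unconditional branches in one doubled batch (torch.cat on x, then pred.chunk(2, 0)); the new path makes two sequential transformer calls and skips the unconditional call entirely when cfg_strength is effectively zero. The classifier-free guidance rule itself is unchanged. A toy sketch of the extrapolation, with made-up tensors:

    import torch

    def cfg_combine(pred, null_pred, cfg_strength):
        # Classifier-free guidance: extrapolate away from the unconditional
        # prediction. Strength 0 returns pred unchanged; larger values push
        # the output further along the direction (pred - null_pred).
        return pred + (pred - null_pred) * cfg_strength

    pred = torch.tensor([1.0, 2.0])       # conditional prediction (toy values)
    null_pred = torch.tensor([0.5, 2.5])  # prediction with text/audio cond dropped
    print(cfg_combine(pred, null_pred, 4.0))  # tensor([3., 0.])

The apparent trade-off: roughly the same total compute when guidance is on, but half the peak batch size per forward pass, and the second pass is avoided altogether when guidance is off.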
diffrhythm/model/dit.py CHANGED
@@ -162,21 +162,25 @@ class DiT(nn.Module):
     def forward(
         self,
         x: float["b n d"],  # noised input audio  # noqa: F722
-        text_embed: int["b nt"],  # text  # noqa: F722
-        text_residuals,
         cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
         drop_prompt=False,
         style_prompt=None,  # [b d t]
         start_time=None,
     ):
+
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
             time = time.repeat(batch)
 
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        c = t + start_time
+        s_t = self.start_time_embed(start_time)
+        c = t + s_t
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
         if drop_prompt:
             style_prompt = torch.zeros_like(style_prompt)
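
Note on this hunk: forward now takes raw text ids plus a drop_text flag instead of precomputed text_embed/text_residuals, so the text embedding and the start-time embedding are recomputed on every denoising step inside the model. A minimal self-contained sketch of the resulting conditioning assembly; the module types and dimensions below are illustrative stand-ins, not this repo's actual layers:

    import torch
    import torch.nn as nn

    class CondSketch(nn.Module):
        def __init__(self, dim=64, vocab=256):
            super().__init__()
            self.time_embed = nn.Linear(1, dim)         # stand-in timestep embed
            self.start_time_embed = nn.Linear(1, dim)   # stand-in start-time embed
            self.text_embed = nn.Embedding(vocab, dim)  # stand-in text embed

        def forward(self, text, time, start_time, drop_text=False):
            t = self.time_embed(time[:, None])                # [b, dim]
            s_t = self.start_time_embed(start_time[:, None])  # [b, dim]
            c = t + s_t                                       # summed, as in c = t + s_t
            if drop_text:                                     # CFG null-text branch
                text = torch.zeros_like(text)                 # (the real embed handles this internally)
            return c, self.text_embed(text)                   # [b, dim], [b, nt, dim]

    m = CondSketch()
    c, te = m(torch.randint(1, 256, (2, 10)), torch.rand(2), torch.rand(2))

This mirrors the new call sites in cfm.py, where both the conditional pass (drop_text=False) and the null pass (drop_text=True) go through the same forward.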